From 9e0aada6c7a8df85c8d9d445e27b84ef2be264c3 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Tue, 6 May 2025 16:58:24 +0530 Subject: [PATCH 01/16] we make progress --- .gitignore | 7 +++++ code/config.py | 20 ++++++------ code/dataset.py | 18 +++++------ code/eeg_ldm.py | 2 +- code/gen_eval_eeg.py | 8 ++--- code/stageA1_eeg_pretrain.py | 4 +-- env114.yaml | 31 +++++++++++++++++++ .../generation/06-05-2025-16-14-36/README.md | 1 + monREADME.md | 22 +++++++++++++ 9 files changed, 87 insertions(+), 26 deletions(-) create mode 100644 .gitignore create mode 100644 env114.yaml create mode 100644 exps/results/generation/06-05-2025-16-14-36/README.md create mode 100644 monREADME.md diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..df115b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.DS_Store + +__pycache__/ + +*.egg-info/ +datasets/* +pretrains/models/v1-5-pruned.ckpt diff --git a/code/config.py b/code/config.py index 493c884..93d5cf8 100644 --- a/code/config.py +++ b/code/config.py @@ -31,8 +31,8 @@ def __init__(self): self.mlp_ratio = 1.0 # Project setting - self.root_path = '../dreamdiffusion/' - self.output_path = '../dreamdiffusion/exps/' + self.root_path = './' + self.output_path = './exps/' self.seed = 2022 self.roi = 'VC' self.aug_times = 1 @@ -54,15 +54,15 @@ class Config_EEG_finetune(Config_MBM_finetune): def __init__(self): # Project setting - self.root_path = '../dreamdiffusion/' + self.root_path = './' # self.root_path = '.' - self.output_path = '../dreamdiffusion/exps/' + self.output_path = './exps/' self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth') self.splits_path = os.path.join(self.root_path, 'datasets/block_splits_by_image_all.pth') self.dataset = 'EEG' - self.pretrain_mbm_path = '../dreamdiffusion/pretrains/eeg_pretrain/checkpoint.pth' + self.pretrain_mbm_path = './pretrains/eeg_pretrain/checkpoint.pth' self.include_nonavg_test = True @@ -89,8 +89,8 @@ class Config_Generative_Model: def __init__(self): # project parameters self.seed = 2022 - self.root_path = '../dreamdiffusion/' - self.output_path = '../dreamdiffusion/exps/' + self.root_path = './' + self.output_path = './exps/' self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth') self.splits_path = os.path.join(self.root_path, 'datasets/block_splits_by_image_single.pth') @@ -139,8 +139,8 @@ class Config_Cls_Model: def __init__(self): # project parameters self.seed = 2022 - self.root_path = '../dreamdiffusion/' - self.output_path = '../dreamdiffusion/exps/' + self.root_path = './' + self.output_path = './exps/' # self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth') self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_14_70_std.pth') @@ -181,4 +181,4 @@ def __init__(self): self.HW = None # resume check util self.model_meta = None - self.checkpoint_path = None \ No newline at end of file + self.checkpoint_path = None diff --git a/code/dataset.py b/code/dataset.py index 3c8957b..56a6001 100644 --- a/code/dataset.py +++ b/code/dataset.py @@ -106,7 +106,7 @@ def is_npy_ext(fname: Union[str, Path]) -> bool: return f'{ext}' == 'npy'# type: ignore class eeg_pretrain_dataset(Dataset): - def __init__(self, path='../dreamdiffusion/datasets/mne_data/', roi='VC', patch_size=16, transform=identity, aug_times=2, + def __init__(self, path='./datasets/mne_data/', roi='VC', patch_size=16, transform=identity, aug_times=2, num_sub_limit=None, include_kam=False, include_hcp=True): super(eeg_pretrain_dataset, self).__init__() data = [] @@ -240,7 +240,7 @@ class EEGDataset_r(Dataset): # Constructor def __init__(self, eeg_signals_path, image_transform=identity): - self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/' + self.imagenet = './datasets/imageNet_2012/train/' self.image_transform = image_transform self.num_voxels = 440 self.data_len = 512 @@ -386,12 +386,12 @@ def __getitem__(self, i): return self.dataset[self.split_idx[i]] -def create_EEG_dataset(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std.pth', - splits_path = '../dreamdiffusion/datasets/block_splits_by_image_single.pth', - # splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth', +def create_EEG_dataset(eeg_signals_path='./datasets/eeg_5_95_std.pth', + splits_path = './datasets/block_splits_by_image_single.pth', + # splits_path = './datasets/block_splits_by_image_all.pth', image_transform=identity, subject = 0): # if subject == 0: - # splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth' + # splits_path = './datasets/block_splits_by_image_all.pth' if isinstance(image_transform, list): dataset_train = EEGDataset(eeg_signals_path, image_transform[0], subject ) dataset_test = EEGDataset(eeg_signals_path, image_transform[1], subject) @@ -405,9 +405,9 @@ def create_EEG_dataset(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std -def create_EEG_dataset_r(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std.pth', - # splits_path = '../dreamdiffusion/datasets/block_splits_by_image_single.pth', - splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth', +def create_EEG_dataset_r(eeg_signals_path='./datasets/eeg_5_95_std.pth', + # splits_path = './datasets/block_splits_by_image_single.pth', + splits_path = './datasets/block_splits_by_image_all.pth', image_transform=identity): if isinstance(image_transform, list): dataset_train = EEGDataset_r(eeg_signals_path, image_transform[0]) diff --git a/code/eeg_ldm.py b/code/eeg_ldm.py index fddd98a..a780174 100644 --- a/code/eeg_ldm.py +++ b/code/eeg_ldm.py @@ -178,7 +178,7 @@ def get_args_parser(): parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False) # project parameters parser.add_argument('--seed', type=int) - parser.add_argument('--root_path', type=str, default = '../dreamdiffusion/') + parser.add_argument('--root_path', type=str, default = './') parser.add_argument('--pretrain_mbm_path', type=str) parser.add_argument('--checkpoint_path', type=str) parser.add_argument('--crop_ratio', type=float) diff --git a/code/gen_eval_eeg.py b/code/gen_eval_eeg.py index a079f51..37e0454 100644 --- a/code/gen_eval_eeg.py +++ b/code/gen_eval_eeg.py @@ -51,7 +51,7 @@ def __call__(self, img): def get_args_parser(): parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False) # project parameters - parser.add_argument('--root', type=str, default='../dreamdiffusion/') + parser.add_argument('--root', type=str, default='./') parser.add_argument('--dataset', type=str, default='GOD') parser.add_argument('--model_path', type=str) @@ -68,8 +68,8 @@ def get_args_parser(): config = sd['config'] # update paths config.root_path = root - config.pretrain_mbm_path = '../dreamdiffusion/results/eeg_pretrain/19-02-2023-08-48-17/checkpoints/checkpoint.pth' - config.pretrain_gm_path = '../dreamdiffusion/pretrains/' + config.pretrain_mbm_path = './results/eeg_pretrain/19-02-2023-08-48-17/checkpoints/checkpoint.pth' + config.pretrain_gm_path = './pretrains/' print(config.__dict__) output_path = os.path.join(config.root_path, 'results', 'eval', @@ -91,7 +91,7 @@ def get_args_parser(): ]) - splits_path = "../dreamdiffusion/datasets/block_splits_by_image_single.pth" + splits_path = "./datasets/block_splits_by_image_single.pth" dataset_train, dataset_test = create_EEG_dataset(eeg_signals_path = config.eeg_signals_path, splits_path = splits_path, image_transform=[img_transform_train, img_transform_test], subject = 4) num_voxels = dataset_test.dataset.data_len diff --git a/code/stageA1_eeg_pretrain.py b/code/stageA1_eeg_pretrain.py index 7c1d8e6..ada1264 100644 --- a/code/stageA1_eeg_pretrain.py +++ b/code/stageA1_eeg_pretrain.py @@ -121,7 +121,7 @@ def main(config): np.random.seed(config.seed) # create dataset and dataloader - dataset_pretrain = eeg_pretrain_dataset(path='../dreamdiffusion/datasets/mne_data/', roi=config.roi, patch_size=config.patch_size, + dataset_pretrain = eeg_pretrain_dataset(path='./datasets/mne_data/', roi=config.roi, patch_size=config.patch_size, transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, include_kam=config.include_kam, include_hcp=config.include_hcp) @@ -287,4 +287,4 @@ def update_config(args, config): config = Config_MBM_EEG() config = update_config(args, config) main(config) - \ No newline at end of file + diff --git a/env114.yaml b/env114.yaml new file mode 100644 index 0000000..bb749a1 --- /dev/null +++ b/env114.yaml @@ -0,0 +1,31 @@ +name: dreamdiffusion-cu114 +channels: + - conda-forge + - pytorch + - defaults + +dependencies: + - python=3.8.5 + - pip>=20.3 + - pytorch-lightning=1.6.5 # installed via conda-forge + - pip: + - numpy + - matplotlib + - natsort + - kornia + - omegaconf==2.1.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - --extra-index-url https://download.pytorch.org/whl/cu114 + - torch==1.12.1 + - torchvision==0.13.1 + - Pillow==9.0.1 + - timm==0.5.4 + - tqdm==4.64.0 + - wandb==0.12.21 + - torchmetrics==0.9.2 + - scikit-image + - lpips==0.1.4 + - transformers==4.20.1 + - -e ./code + diff --git a/exps/results/generation/06-05-2025-16-14-36/README.md b/exps/results/generation/06-05-2025-16-14-36/README.md new file mode 100644 index 0000000..4f92441 --- /dev/null +++ b/exps/results/generation/06-05-2025-16-14-36/README.md @@ -0,0 +1 @@ +{'seed': 2022, 'root_path': './', 'output_path': './exps/results/generation/06-05-2025-16-14-36', 'eeg_signals_path': './datasets/eeg_5_95_std.pth', 'splits_path': './datasets/block_splits_by_image_single.pth', 'roi': 'VC', 'patch_size': 4, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'mlp_ratio': 1.0, 'pretrain_gm_path': './pretrains', 'dataset': 'EEG', 'pretrain_mbm_path': './pretrains/eeg_pretrain/checkpoint.pth', 'img_size': 512, 'batch_size': 4, 'lr': 5.3e-05, 'num_epoch': 300, 'precision': 32, 'accumulate_grad': 1, 'crop_ratio': 0.2, 'global_pool': False, 'use_time_cond': True, 'clip_tune': True, 'cls_tune': False, 'subject': 4, 'eval_avg': True, 'num_samples': 5, 'ddim_steps': 250, 'HW': None, 'model_meta': None, 'checkpoint_path': None} diff --git a/monREADME.md b/monREADME.md new file mode 100644 index 0000000..2c84462 --- /dev/null +++ b/monREADME.md @@ -0,0 +1,22 @@ +# Dataset + +After going through the dataset: + +In eeg-5-95-std.pth: (5 to 95 Hz filter) +'dataset': + (of length 11965) + 'eeg': shape (128, 500) (128 channels, 500 points) + 'subject': 4 (out of 6 subjects) + 'label': 10 (out of 40) + 'image': 0 (out of 1996) + +'labels': + (of length 40) + label name (n........) + +'images': + (of length 1996) + image name (n........_.....) + + +In eeg-t-95 From f6db05c021152ae1851c8ad09349a905df3be007 Mon Sep 17 00:00:00 2001 From: Ulysse Durand <61126385+UlysseDurand@users.noreply.github.com> Date: Tue, 6 May 2025 17:17:31 +0530 Subject: [PATCH 02/16] Delete exps directory --- exps/results/generation/06-05-2025-16-14-36/README.md | 1 - 1 file changed, 1 deletion(-) delete mode 100644 exps/results/generation/06-05-2025-16-14-36/README.md diff --git a/exps/results/generation/06-05-2025-16-14-36/README.md b/exps/results/generation/06-05-2025-16-14-36/README.md deleted file mode 100644 index 4f92441..0000000 --- a/exps/results/generation/06-05-2025-16-14-36/README.md +++ /dev/null @@ -1 +0,0 @@ -{'seed': 2022, 'root_path': './', 'output_path': './exps/results/generation/06-05-2025-16-14-36', 'eeg_signals_path': './datasets/eeg_5_95_std.pth', 'splits_path': './datasets/block_splits_by_image_single.pth', 'roi': 'VC', 'patch_size': 4, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'mlp_ratio': 1.0, 'pretrain_gm_path': './pretrains', 'dataset': 'EEG', 'pretrain_mbm_path': './pretrains/eeg_pretrain/checkpoint.pth', 'img_size': 512, 'batch_size': 4, 'lr': 5.3e-05, 'num_epoch': 300, 'precision': 32, 'accumulate_grad': 1, 'crop_ratio': 0.2, 'global_pool': False, 'use_time_cond': True, 'clip_tune': True, 'cls_tune': False, 'subject': 4, 'eval_avg': True, 'num_samples': 5, 'ddim_steps': 250, 'HW': None, 'model_meta': None, 'checkpoint_path': None} From eb793837cd38c1ff376add859c68774723ca50bc Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Tue, 6 May 2025 17:18:47 +0530 Subject: [PATCH 03/16] better gitignore: --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index df115b1..e96e9ae 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ __pycache__/ *.egg-info/ datasets/* pretrains/models/v1-5-pruned.ckpt +exps/* +results/* From d5ed7339afc9fcec8fb427edc379859be2836884 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Tue, 6 May 2025 17:21:16 +0530 Subject: [PATCH 04/16] rectified README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 08c25a4..aa6c3c2 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_eeg_pr In this stage, the cross-attention heads and pre-trained EEG encoder will be jointly optimized with EEG-image pairs. ```sh -python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ../dreamdiffusion/pretrains/eeg_pretrain/checkpoint.pth +python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint.pth ``` From 9e08ea72e6ac615b16a8bbdad04c1bf0cb138570 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Mon, 12 May 2025 17:35:05 +0530 Subject: [PATCH 05/16] Finally working --- .gitignore | 8 +++- README.md | 58 ++++++++++++++++++++++------ code/dataset.py | 4 +- code/gen_eval_eeg.py | 5 +-- code/sc_mbm/mae_for_eeg.py | 2 +- env.yaml | 5 ++- env114.yaml | 31 --------------- pretrains/generation/changeconfig.py | 13 +++++++ pretrains/generation/showconfig.py | 27 +++++++++++++ 9 files changed, 100 insertions(+), 53 deletions(-) delete mode 100644 env114.yaml create mode 100644 pretrains/generation/changeconfig.py create mode 100644 pretrains/generation/showconfig.py diff --git a/.gitignore b/.gitignore index e96e9ae..cc0ac2e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,13 @@ .DS_Store __pycache__/ - +dist/ *.egg-info/ + datasets/* -pretrains/models/v1-5-pruned.ckpt exps/* results/* + +pretrains/models/v1-5-pruned.ckpt +pretrains/eeg-pretrain/*.pth +pretrains/generation/*.pth diff --git a/README.md b/README.md index aa6c3c2..995a914 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,23 @@ This document introduces the precesedures required for replicating the results i ## Abstract This paper introduces DreamDiffusion, a novel method for generating high-quality images directly from brain electroencephalogram (EEG) signals, without the need to translate thoughts into text. DreamDiffusion leverages pre-trained text-to-image models and employs temporal masked signal modeling to pre-train the EEG encoder for effective and robust EEG representations. Additionally, the method further leverages the CLIP image encoder to provide extra supervision to better align EEG, text, and image embeddings with limited EEG-image pairs. Overall, the proposed method overcomes the challenges of using EEG signals for image generation, such as noise, limited information, and individual differences, and achieves promising results. Quantitative and qualitative results demonstrate the effectiveness of the proposed method as a significant step towards portable and low-cost "thoughts-to-image", with potential applications in neuroscience and computer vision. +# WARNING + +This code works in three steps, the first step isn't directly reproducible from this directory. +This step only needs many different EEG signals (around 120 000 EEGs are used in the paper), it is specified that +the datasets can be picked from the MOABB datasets, but the specific datasets aren't specified. + +[this repository](https://github.com/alinvdu/reproduce-dream-diffusion/) +proposes to use only the EEGs from the (EEG, Image) pairs used for the 2nd step. ## Overview ![pipeline](assets/eeg_pipeline.png) -The **datasets** folder and **pretrains** folder are not included in this repository. +The **datasets** folder is not included in this repository. Neither are the pretrained models' checkpoints. Please download them from [eeg](https://github.com/perceivelab/eeg_visual_classification) and put them in the root directory of this repository as shown below. We also provide a copy of the Imagenet subset [imagenet](https://drive.google.com/file/d/1y7I9bG1zKYqBM94odcox_eQjnP9HGo9-/view?usp=drive_link). -For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). +For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). File path | Description ``` @@ -30,9 +38,12 @@ File path | Description ┣ 📂 generation ┃ ┗ 📜 checkpoint_best.pth +┃ ┗ 📜 checkpoint.pth +┃ ┗ 📜 changeconfig.py +┃ ┗ 📜 showconfig.py ┣ 📂 eeg_pretain -┃ ┗ 📜 checkpoint.pth (pre-trained EEG encoder) +┃ ┗ 📜 checkpoint-eeg-500.pth (pre-trained EEG encoder) /datasets ┣ 📂 imageNet_images (subset of Imagenet) @@ -76,11 +87,19 @@ conda activate dreamdiffusion ## Download checkpoints -We also checkpoints to run the finetuing and decoding directly. +### Pre-training on the EEG data + +Please find the checkpoint after pretraining the EEG Autoencoder in [this repository](https://github.com/alinvdu/reproduce-dream-diffusion/). + +And put it in ```pretrains/eeg-pretrain/checkpoint-eeg-500.pth``` +## Training the models +Here is how to make the checkpoints (remember to move them from the +```results``` directory to the appropriate location in the ```pretrains``` +directory) -## Pre-training on EEG data +### Pre-training on EEG data You can download the dataset for pretraining from here [MOABB](https://github.com/NeuroTechX/moabb). @@ -101,20 +120,37 @@ Multiple-GPU (DDP) training is supported, run with python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_eeg_pretrain.py ``` - - -## Finetune the Stable Diffusion with Pre-trained EEG Encoder +### Finetune the Stable Diffusion with Pre-trained EEG Encoder In this stage, the cross-attention heads and pre-trained EEG encoder will be jointly optimized with EEG-image pairs. +Before + ```sh -python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint.pth +python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint-eeg-500.pth ``` -## Generating Images with Trained Checkpoints +### Generating Images with Trained Checkpoints Run this stage with our provided checkpoints: Here we provide a checkpoint [ckpt](https://drive.google.com/file/d/1Ygplxe1TB68-aYu082bjc89nD8Ngklnc/view?usp=drive_link), which you may want to try. + +Do +```sh +cd pretrains/generation +python changeconfig.py +``` +to change some of the checkpoint's config parameters and save the modified +checkpoint as `checkpoint2.pth` + +You can print one checkpoint's config parameters with +```sh +cd pretrains/generation +python showconfig.py checkpoint.pth +``` + +Finally, generate images + ```sh -python3 code/gen_eval_eeg.py --dataset EEG --model_path pretrains/generation/checkpoint.pth +python3 code/gen_eval_eeg.py --dataset EEG --model_path pretrains/generation/checkpoint2.pth ``` diff --git a/code/dataset.py b/code/dataset.py index 56a6001..c17891b 100644 --- a/code/dataset.py +++ b/code/dataset.py @@ -276,7 +276,7 @@ def __init__(self, eeg_signals_path, image_transform=identity): self.data = loaded['dataset'] self.labels = loaded["labels"] self.images = loaded["images"] - self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/' + self.imagenet = './datasets/imageNet_images/' self.image_transform = image_transform self.num_voxels = 440 # Compute size @@ -316,7 +316,7 @@ def __init__(self, eeg_signals_path, image_transform=identity, subject = 4): self.data = loaded['dataset'] self.labels = loaded["labels"] self.images = loaded["images"] - self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/' + self.imagenet = './datasets/imageNet_images' self.image_transform = image_transform self.num_voxels = 440 self.data_len = 512 diff --git a/code/gen_eval_eeg.py b/code/gen_eval_eeg.py index 37e0454..c7ee884 100644 --- a/code/gen_eval_eeg.py +++ b/code/gen_eval_eeg.py @@ -68,10 +68,6 @@ def get_args_parser(): config = sd['config'] # update paths config.root_path = root - config.pretrain_mbm_path = './results/eeg_pretrain/19-02-2023-08-48-17/checkpoints/checkpoint.pth' - config.pretrain_gm_path = './pretrains/' - print(config.__dict__) - output_path = os.path.join(config.root_path, 'results', 'eval', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) @@ -99,6 +95,7 @@ def get_args_parser(): # num_voxels = dataset_test.num_voxels print(len(dataset_test)) # prepare pretrained mae + print(config.pretrain_mbm_path) pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu') # create generateive model generative_model = eLDM(pretrain_mbm_metafile, num_voxels, diff --git a/code/sc_mbm/mae_for_eeg.py b/code/sc_mbm/mae_for_eeg.py index 90ed837..d4b14ab 100644 --- a/code/sc_mbm/mae_for_eeg.py +++ b/code/sc_mbm/mae_for_eeg.py @@ -1,5 +1,5 @@ import sys -sys.path.append('../dreamdiffusion/code/') +sys.path.append('./code/') # print(sys.path) import sc_mbm.utils as ut import torch diff --git a/env.yaml b/env.yaml index ebfecab..fd123ce 100644 --- a/env.yaml +++ b/env.yaml @@ -4,9 +4,10 @@ channels: - defaults dependencies: - python=3.8.5 - - pip>=20.3 + - pip>=20.3,<24.1 - pip: - - numpy + - numpy==1.19.5 + - transformers==4.36.2 - matplotlib - natsort - kornia diff --git a/env114.yaml b/env114.yaml deleted file mode 100644 index bb749a1..0000000 --- a/env114.yaml +++ /dev/null @@ -1,31 +0,0 @@ -name: dreamdiffusion-cu114 -channels: - - conda-forge - - pytorch - - defaults - -dependencies: - - python=3.8.5 - - pip>=20.3 - - pytorch-lightning=1.6.5 # installed via conda-forge - - pip: - - numpy - - matplotlib - - natsort - - kornia - - omegaconf==2.1.1 - - einops==0.3.0 - - torch-fidelity==0.3.0 - - --extra-index-url https://download.pytorch.org/whl/cu114 - - torch==1.12.1 - - torchvision==0.13.1 - - Pillow==9.0.1 - - timm==0.5.4 - - tqdm==4.64.0 - - wandb==0.12.21 - - torchmetrics==0.9.2 - - scikit-image - - lpips==0.1.4 - - transformers==4.20.1 - - -e ./code - diff --git a/pretrains/generation/changeconfig.py b/pretrains/generation/changeconfig.py new file mode 100644 index 0000000..ca10697 --- /dev/null +++ b/pretrains/generation/changeconfig.py @@ -0,0 +1,13 @@ +import torch +from collections import OrderedDict +import sys + +data = torch.load('checkpoint.pth') +config = data['config'] + +config.eeg_signals_path = "./datasets/eeg_5_95_std.pth" +config.splits_path = "./datasets/block_splits_by_image_single.pth" +config.pretrain_mbm_path = "./pretrains/eeg-pretrain/checkpoint-eeg-500.pth" +config.pretrain_gm_path = "./pretrains" + +torch.save(data, 'checkpoint2.pth') diff --git a/pretrains/generation/showconfig.py b/pretrains/generation/showconfig.py new file mode 100644 index 0000000..02188fa --- /dev/null +++ b/pretrains/generation/showconfig.py @@ -0,0 +1,27 @@ +import torch +from collections import OrderedDict +import sys + +data = torch.load(sys.argv[1]) +print(f"checkpoint keys: {data.keys()}") + +print(f"printing checkpoint['model_state_dict']") +model_state_dict = data['model_state_dict'] +#for k, v in model_state_dict.items(): +# print(k, v.cpu().size()) # or v.cpu().numpy() if you want the values +print("-----") + +print(f"printing checkpoint['config']") +config = data['config'] +print(config) + +print() +for key, value in config.__dict__.items(): + print(f"{key}: {value}") +print("-----") + + +print(f"printing checkpoint['state']") +state = data['state'] +print(state.shape) +print("-----") From 781834c5fd8d55b84f9df37c6fdb9bfb982026f9 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Mon, 12 May 2025 17:37:19 +0530 Subject: [PATCH 06/16] Without monREADME.md --- monREADME.md | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 monREADME.md diff --git a/monREADME.md b/monREADME.md deleted file mode 100644 index 2c84462..0000000 --- a/monREADME.md +++ /dev/null @@ -1,22 +0,0 @@ -# Dataset - -After going through the dataset: - -In eeg-5-95-std.pth: (5 to 95 Hz filter) -'dataset': - (of length 11965) - 'eeg': shape (128, 500) (128 channels, 500 points) - 'subject': 4 (out of 6 subjects) - 'label': 10 (out of 40) - 'image': 0 (out of 1996) - -'labels': - (of length 40) - label name (n........) - -'images': - (of length 1996) - image name (n........_.....) - - -In eeg-t-95 From e722ae897c685cad78dec69bdc4b86c066bb1df5 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Mon, 12 May 2025 17:51:13 +0530 Subject: [PATCH 07/16] Some rectifications --- README.md | 25 +++---------------------- code/gen_eval_eeg.py | 6 +++++- pretrains/generation/changeconfig.py | 13 ------------- pretrains/generation/showconfig.py | 27 --------------------------- 4 files changed, 8 insertions(+), 63 deletions(-) delete mode 100644 pretrains/generation/changeconfig.py delete mode 100644 pretrains/generation/showconfig.py diff --git a/README.md b/README.md index 995a914..dac4372 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ This step only needs many different EEG signals (around 120 000 EEGs are used in the datasets can be picked from the MOABB datasets, but the specific datasets aren't specified. [this repository](https://github.com/alinvdu/reproduce-dream-diffusion/) -proposes to use only the EEGs from the (EEG, Image) pairs used for the 2nd step. +proposes to use only the EEGs from the (EEG, Image) pairs used for the 2nd step but it makes poor results. ## Overview ![pipeline](assets/eeg_pipeline.png) @@ -26,7 +26,7 @@ proposes to use only the EEGs from the (EEG, Image) pairs used for the 2nd step. The **datasets** folder is not included in this repository. Neither are the pretrained models' checkpoints. Please download them from [eeg](https://github.com/perceivelab/eeg_visual_classification) and put them in the root directory of this repository as shown below. We also provide a copy of the Imagenet subset [imagenet](https://drive.google.com/file/d/1y7I9bG1zKYqBM94odcox_eQjnP9HGo9-/view?usp=drive_link). -For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). +For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). File path | Description ``` @@ -123,34 +123,15 @@ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_eeg_pr ### Finetune the Stable Diffusion with Pre-trained EEG Encoder In this stage, the cross-attention heads and pre-trained EEG encoder will be jointly optimized with EEG-image pairs. -Before - ```sh python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint-eeg-500.pth ``` - ### Generating Images with Trained Checkpoints Run this stage with our provided checkpoints: Here we provide a checkpoint [ckpt](https://drive.google.com/file/d/1Ygplxe1TB68-aYu082bjc89nD8Ngklnc/view?usp=drive_link), which you may want to try. -Do -```sh -cd pretrains/generation -python changeconfig.py -``` -to change some of the checkpoint's config parameters and save the modified -checkpoint as `checkpoint2.pth` - -You can print one checkpoint's config parameters with -```sh -cd pretrains/generation -python showconfig.py checkpoint.pth -``` - -Finally, generate images - ```sh -python3 code/gen_eval_eeg.py --dataset EEG --model_path pretrains/generation/checkpoint2.pth +python3 code/gen_eval_eeg.py --dataset EEG --model_path pretrains/generation/checkpoint.pth ``` diff --git a/code/gen_eval_eeg.py b/code/gen_eval_eeg.py index c7ee884..c9e1720 100644 --- a/code/gen_eval_eeg.py +++ b/code/gen_eval_eeg.py @@ -68,6 +68,11 @@ def get_args_parser(): config = sd['config'] # update paths config.root_path = root + config.eeg_signals_path = "./datasets/eeg_5_95_std.pth" + config.splits_path = "./datasets/block_splits_by_image_single.pth" + config.pretrain_mbm_path = "./pretrains/eeg-pretrain/checkpoint-eeg-500.pth" + config.pretrain_gm_path = './pretrains/' + output_path = os.path.join(config.root_path, 'results', 'eval', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) @@ -95,7 +100,6 @@ def get_args_parser(): # num_voxels = dataset_test.num_voxels print(len(dataset_test)) # prepare pretrained mae - print(config.pretrain_mbm_path) pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu') # create generateive model generative_model = eLDM(pretrain_mbm_metafile, num_voxels, diff --git a/pretrains/generation/changeconfig.py b/pretrains/generation/changeconfig.py deleted file mode 100644 index ca10697..0000000 --- a/pretrains/generation/changeconfig.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -from collections import OrderedDict -import sys - -data = torch.load('checkpoint.pth') -config = data['config'] - -config.eeg_signals_path = "./datasets/eeg_5_95_std.pth" -config.splits_path = "./datasets/block_splits_by_image_single.pth" -config.pretrain_mbm_path = "./pretrains/eeg-pretrain/checkpoint-eeg-500.pth" -config.pretrain_gm_path = "./pretrains" - -torch.save(data, 'checkpoint2.pth') diff --git a/pretrains/generation/showconfig.py b/pretrains/generation/showconfig.py deleted file mode 100644 index 02188fa..0000000 --- a/pretrains/generation/showconfig.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -from collections import OrderedDict -import sys - -data = torch.load(sys.argv[1]) -print(f"checkpoint keys: {data.keys()}") - -print(f"printing checkpoint['model_state_dict']") -model_state_dict = data['model_state_dict'] -#for k, v in model_state_dict.items(): -# print(k, v.cpu().size()) # or v.cpu().numpy() if you want the values -print("-----") - -print(f"printing checkpoint['config']") -config = data['config'] -print(config) - -print() -for key, value in config.__dict__.items(): - print(f"{key}: {value}") -print("-----") - - -print(f"printing checkpoint['state']") -state = data['state'] -print(state.shape) -print("-----") From e880aa366b2b651b7b0d730ad7870789df242bdf Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Mon, 12 May 2025 18:23:33 +0530 Subject: [PATCH 08/16] better --- README.md | 14 +++++++------- code/eeg_ldm.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index dac4372..b00173b 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,6 @@ This code works in three steps, the first step isn't directly reproducible from This step only needs many different EEG signals (around 120 000 EEGs are used in the paper), it is specified that the datasets can be picked from the MOABB datasets, but the specific datasets aren't specified. -[this repository](https://github.com/alinvdu/reproduce-dream-diffusion/) -proposes to use only the EEGs from the (EEG, Image) pairs used for the 2nd step but it makes poor results. ## Overview ![pipeline](assets/eeg_pipeline.png) @@ -39,8 +37,6 @@ File path | Description ┣ 📂 generation ┃ ┗ 📜 checkpoint_best.pth ┃ ┗ 📜 checkpoint.pth -┃ ┗ 📜 changeconfig.py -┃ ┗ 📜 showconfig.py ┣ 📂 eeg_pretain ┃ ┗ 📜 checkpoint-eeg-500.pth (pre-trained EEG encoder) @@ -85,14 +81,15 @@ conda env create -f env.yaml conda activate dreamdiffusion ``` -## Download checkpoints +## Download checkpoint after pre-training on EEG data -### Pre-training on the EEG data +No checkpoint provided -Please find the checkpoint after pretraining the EEG Autoencoder in [this repository](https://github.com/alinvdu/reproduce-dream-diffusion/). +To make the code work (poorly), please find a checkpoint after pretraining the EEG Autoencoder in [this repository](https://github.com/alinvdu/reproduce-dream-diffusion/). They propose to use only the EEGs from the (EEG, Image) pairs but it makes poor results as this is few data. And put it in ```pretrains/eeg-pretrain/checkpoint-eeg-500.pth``` + ## Training the models Here is how to make the checkpoints (remember to move them from the @@ -103,6 +100,9 @@ directory) You can download the dataset for pretraining from here [MOABB](https://github.com/NeuroTechX/moabb). +The datasets used to get the papers's results aren't specified, but you can put +any EEG dataset, put them in ```datasets/mne_data/``` in a `.npy` format. + To perform the pre-training from scratch with defaults parameters, run ```sh python3 code/stageA1_eeg_pretrain.py diff --git a/code/eeg_ldm.py b/code/eeg_ldm.py index a780174..4aa1c30 100644 --- a/code/eeg_ldm.py +++ b/code/eeg_ldm.py @@ -12,7 +12,7 @@ import copy # own code -from config import Config_Generative_Model +from config import Config_Generative_Model, Config_MBM_EEG from dataset import create_EEG_dataset from dc_ldm.ldm_for_eeg import eLDM from eval_metrics import get_similarity_metric From 0139dec7aa873d82d0a75df17a7ce8c27f6340d1 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Mon, 12 May 2025 18:25:10 +0530 Subject: [PATCH 09/16] easier for PR --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index b00173b..14a1bd3 100644 --- a/README.md +++ b/README.md @@ -10,13 +10,6 @@ This document introduces the precesedures required for replicating the results i ## Abstract This paper introduces DreamDiffusion, a novel method for generating high-quality images directly from brain electroencephalogram (EEG) signals, without the need to translate thoughts into text. DreamDiffusion leverages pre-trained text-to-image models and employs temporal masked signal modeling to pre-train the EEG encoder for effective and robust EEG representations. Additionally, the method further leverages the CLIP image encoder to provide extra supervision to better align EEG, text, and image embeddings with limited EEG-image pairs. Overall, the proposed method overcomes the challenges of using EEG signals for image generation, such as noise, limited information, and individual differences, and achieves promising results. Quantitative and qualitative results demonstrate the effectiveness of the proposed method as a significant step towards portable and low-cost "thoughts-to-image", with potential applications in neuroscience and computer vision. -# WARNING - -This code works in three steps, the first step isn't directly reproducible from this directory. -This step only needs many different EEG signals (around 120 000 EEGs are used in the paper), it is specified that -the datasets can be picked from the MOABB datasets, but the specific datasets aren't specified. - - ## Overview ![pipeline](assets/eeg_pipeline.png) From 8c5229d6a350e416dcc2a52414bb4e833d128633 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Mon, 12 May 2025 18:28:44 +0530 Subject: [PATCH 10/16] better --- README.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 14a1bd3..273c058 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,6 @@ File path | Description ┃ ┗ 📜 v1-5-pruned.ckpt ┣ 📂 generation -┃ ┗ 📜 checkpoint_best.pth ┃ ┗ 📜 checkpoint.pth ┣ 📂 eeg_pretain @@ -82,14 +81,10 @@ To make the code work (poorly), please find a checkpoint after pretraining the E And put it in ```pretrains/eeg-pretrain/checkpoint-eeg-500.pth``` +## Pre-training on EEG data -## Training the models - -Here is how to make the checkpoints (remember to move them from the -```results``` directory to the appropriate location in the ```pretrains``` -directory) - -### Pre-training on EEG data +remember to move generated checkpoints from the ```results``` directory to the +appropriate location in the ```pretrains``` directory You can download the dataset for pretraining from here [MOABB](https://github.com/NeuroTechX/moabb). @@ -113,14 +108,18 @@ Multiple-GPU (DDP) training is supported, run with python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_eeg_pretrain.py ``` -### Finetune the Stable Diffusion with Pre-trained EEG Encoder +## Finetune the Stable Diffusion with Pre-trained EEG Encoder + +remember to move generated checkpoints from the ```results``` directory to the +appropriate location in the ```pretrains``` directory + In this stage, the cross-attention heads and pre-trained EEG encoder will be jointly optimized with EEG-image pairs. ```sh python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint-eeg-500.pth ``` -### Generating Images with Trained Checkpoints +## Generating Images with Trained Checkpoints Run this stage with our provided checkpoints: Here we provide a checkpoint [ckpt](https://drive.google.com/file/d/1Ygplxe1TB68-aYu082bjc89nD8Ngklnc/view?usp=drive_link), which you may want to try. ```sh From 333b2a1e4af7e6cf9ee45f9d3fd85b538dce6dfd Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Tue, 13 May 2025 10:37:47 +0530 Subject: [PATCH 11/16] added datasetconverter --- .gitignore | 3 ++- datasets/converter.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 datasets/converter.py diff --git a/.gitignore b/.gitignore index cc0ac2e..7887784 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,8 @@ __pycache__/ dist/ *.egg-info/ -datasets/* +datasets/*.pth +datasets/*.JPEG exps/* results/* diff --git a/datasets/converter.py b/datasets/converter.py new file mode 100644 index 0000000..b5d9c9a --- /dev/null +++ b/datasets/converter.py @@ -0,0 +1,13 @@ +import torch +import numpy as np + +data = torch.load("eeg_5_95_std.pth") + +print(f"data: dict with keys {data.keys()}") +print(f"data['dataset']: list of length {len(data['dataset'])}") +print(f"data['dataset'][0]['eeg']: {data['dataset'][0]['eeg'].shape}") + +res1 = [sample['eeg'] for sample in data['dataset']] +print(len(res), [res[i].shape for i in range(5)]) + +np.save("output.npy", np.array(res1, dtype=object)) From 7188ea3c7c4fc4a3fe4fc814f0401a3bcf079bec Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Wed, 14 May 2025 17:48:27 +0530 Subject: [PATCH 12/16] Added a dataset downloader, changed code/dataset.py --- .gitignore | 3 +- code/dataset.py | 9 +-- datasets/dl_moabb_datasets.py | 68 +++++++++++++++++++ datasets/{converter.py => min_size_sample.py} | 5 +- env_moabb.yaml | 18 +++++ 5 files changed, 95 insertions(+), 8 deletions(-) create mode 100644 datasets/dl_moabb_datasets.py rename datasets/{converter.py => min_size_sample.py} (74%) create mode 100644 env_moabb.yaml diff --git a/.gitignore b/.gitignore index 7887784..4893af7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,9 @@ __pycache__/ dist/ *.egg-info/ +datasets/mne_data/ datasets/*.pth -datasets/*.JPEG +datasets/**/*.JPEG exps/* results/* diff --git a/code/dataset.py b/code/dataset.py index c17891b..6ee56c4 100644 --- a/code/dataset.py +++ b/code/dataset.py @@ -113,17 +113,18 @@ def __init__(self, path='./datasets/mne_data/', roi='VC', patch_size=16, transfo images = [] self.input_paths = [str(f) for f in sorted(Path(path).rglob('*')) if is_npy_ext(f) and os.path.isfile(f)] + self.inputs = [np.load(data_path) for data_path in self.input_paths] + self.real_input = np.concatenate(self.inputs, axis=0) + assert len(self.input_paths) != 0, 'No data found' self.data_len = 512 self.data_chan = 128 def __len__(self): - return len(self.input_paths) + return len(self.real_input) def __getitem__(self, index): - data_path = self.input_paths[index] - - data = np.load(data_path) + data = self.real_input[index] if data.shape[-1] > self.data_len: idx = np.random.randint(0, int(data.shape[-1] - self.data_len)+1) diff --git a/datasets/dl_moabb_datasets.py b/datasets/dl_moabb_datasets.py new file mode 100644 index 0000000..c4b1016 --- /dev/null +++ b/datasets/dl_moabb_datasets.py @@ -0,0 +1,68 @@ +import warnings +from urllib3.exceptions import InsecureRequestWarning +from mne.io import BaseRaw +import os + +def find_raw_objects(data, types=(BaseRaw), path=""): + found = [] + + if isinstance(data, dict): + for key, value in data.items(): + new_path = f"{path}.{key}" if path else key + found.extend(find_raw_objects(value, types, new_path)) + + elif isinstance(data, list): + for idx, item in enumerate(data): + new_path = f"{path}[{idx}]" + found.extend(find_raw_objects(item, types, new_path)) + + elif isinstance(data, types): + found.append((path, data)) + + return found + +print() +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=InsecureRequestWarning) +print("WARNING : All 'FutureWarning' and 'InsecureRequestWarning' have been disabled") + +import numpy as np +from moabb.datasets import * + +datasets = [Schirrmeister2017(), GrosseWentrup2009()] +sample_size = 490 + +datas = [] +for dataset in datasets: + try: + print("\n#####") + print(dataset) + print("#####\n") + data = dataset.get_data(subjects=dataset.subject_list[:2]) + for raw_path, raw_data in find_raw_objects(data): + print() + npdata = raw_data.pick_types(eeg=True).get_data() + print(f"Shape of data: {npdata.shape}") + + print(f"Splitting into samples of size {sample_size}") + length = npdata.shape[1] + samples = [] + i=0 + while (i + 1) * sample_size < length: + sample = npdata[:, i * sample_size : (i+1) * sample_size] + samples.append(sample) + i += 1 + + res = np.array(samples) + print(f"New shape: {res.shape}") + + + whole_path = 'mne_data/'+dataset.code+"/"+raw_path+'.npy' + print(f"Saving in {whole_path}") + + os.makedirs(os.path.dirname(whole_path), exist_ok=True) + np.save(whole_path, res) + print() + + except Exception as e: + print(e) diff --git a/datasets/converter.py b/datasets/min_size_sample.py similarity index 74% rename from datasets/converter.py rename to datasets/min_size_sample.py index b5d9c9a..11a0950 100644 --- a/datasets/converter.py +++ b/datasets/min_size_sample.py @@ -8,6 +8,5 @@ print(f"data['dataset'][0]['eeg']: {data['dataset'][0]['eeg'].shape}") res1 = [sample['eeg'] for sample in data['dataset']] -print(len(res), [res[i].shape for i in range(5)]) - -np.save("output.npy", np.array(res1, dtype=object)) +print() +print(f"Smallest sample has size {min(arr.shape[1] for arr in res1)}") diff --git a/env_moabb.yaml b/env_moabb.yaml new file mode 100644 index 0000000..5971711 --- /dev/null +++ b/env_moabb.yaml @@ -0,0 +1,18 @@ +name: dreamdiffusion-moabb +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - numpy + - scipy + - scikit-learn + - matplotlib + - pandas + - seaborn + - mne + - joblib + - pyxdf + - pip: + - moabb From 8ebe11fbe5b977ed4c8803179b45c7f5898c2b70 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Fri, 16 May 2025 10:36:03 +0530 Subject: [PATCH 13/16] ready to code without gpu --- .gitignore | 9 +- code/TimeMAE | 1 + ...pretrain.py => stageA1_eeg_pretrain_dd.py} | 0 code/stageA1_eeg_pretrain_timeMAE.py | 300 ++++++++++++++++++ 4 files changed, 307 insertions(+), 3 deletions(-) create mode 160000 code/TimeMAE rename code/{stageA1_eeg_pretrain.py => stageA1_eeg_pretrain_dd.py} (100%) create mode 100644 code/stageA1_eeg_pretrain_timeMAE.py diff --git a/.gitignore b/.gitignore index 4893af7..b205a8c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,11 +4,14 @@ __pycache__/ dist/ *.egg-info/ -datasets/mne_data/ +datasets/eeg_train/ +datasets/eeg_test/ datasets/*.pth datasets/**/*.JPEG -exps/* -results/* +exps/ +exp/ +results/ +wandb/ pretrains/models/v1-5-pruned.ckpt pretrains/eeg-pretrain/*.pth diff --git a/code/TimeMAE b/code/TimeMAE new file mode 160000 index 0000000..f4401b5 --- /dev/null +++ b/code/TimeMAE @@ -0,0 +1 @@ +Subproject commit f4401b5a3b80056a385078af0caa1a58a04115e4 diff --git a/code/stageA1_eeg_pretrain.py b/code/stageA1_eeg_pretrain_dd.py similarity index 100% rename from code/stageA1_eeg_pretrain.py rename to code/stageA1_eeg_pretrain_dd.py diff --git a/code/stageA1_eeg_pretrain_timeMAE.py b/code/stageA1_eeg_pretrain_timeMAE.py new file mode 100644 index 0000000..d83a26e --- /dev/null +++ b/code/stageA1_eeg_pretrain_timeMAE.py @@ -0,0 +1,300 @@ +import os, sys +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel +import argparse +import time +import timm.optim.optim_factory as optim_factory +import datetime +import matplotlib.pyplot as plt +import wandb +import copy + +from config import Config_MBM_EEG +from dataset import eeg_pretrain_dataset +from sc_mbm.mae_for_eeg import MAEforEEG +from sc_mbm.trainer import train_one_epoch +from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler +from sc_mbm.utils import save_model + +from TimeMAE.model.TimeMAE import TimeMAE + +os.environ["WANDB_START_METHOD"] = "thread" +os.environ['WANDB_DIR'] = "." + +class wandb_logger: + def __init__(self, config): + wandb.init( + project="dreamdiffusion", + anonymous="allow", + group='stageA_sc-mbm', + config=config, + reinit=True) + + self.config = config + self.step = None + + def log(self, name, data, step=None): + if step is None: + wandb.log({name: data}) + else: + wandb.log({name: data}, step=step) + self.step = step + + def watch_model(self, *args, **kwargs): + wandb.watch(*args, **kwargs) + + def log_image(self, name, fig): + if self.step is None: + wandb.log({name: wandb.Image(fig)}) + else: + wandb.log({name: wandb.Image(fig)}, step=self.step) + + def finish(self): + wandb.finish(quiet=True) + +def get_args_parser(): + parser = argparse.ArgumentParser('MBM pre-training for fMRI', add_help=False) + + # Training Parameters + parser.add_argument('--lr', type=float) + parser.add_argument('--weight_decay', type=float) + parser.add_argument('--num_epoch', type=int) + parser.add_argument('--batch_size', type=int) + parser.add_argument('--momentum', type=float) + + # Model Parameters + parser.add_argument('--mask_ratio', type=float) + parser.add_argument('--patch_size', type=int) + parser.add_argument('--embed_dim', type=int) + parser.add_argument('--decoder_embed_dim', type=int) + parser.add_argument('--depth', type=int) + parser.add_argument('--num_heads', type=int) + parser.add_argument('--decoder_num_heads', type=int) + parser.add_argument('--mlp_ratio', type=float) + parser.add_argument('--wave_length', type=int) + parser.add_argument('--attn_heads', type=int) + parser.add_argument('--layers', type=int) + parser.add_argument('--dropout', type=float) + parser.add_argument('--enable_res_parameter', type=int) + parser.add_argument('--vocab_size', type=int) + parser.add_argument('--reg_layers', type=int) + parser.add_argument('--num_class', type=int) + + # Project setting + parser.add_argument('--root_path', type=str) + parser.add_argument('--seed', type=str) + parser.add_argument('--roi', type=str) + parser.add_argument('--aug_times', type=int) + parser.add_argument('--num_sub_limit', type=int) + + parser.add_argument('--include_hcp', type=bool) + parser.add_argument('--include_kam', type=bool) + + parser.add_argument('--use_nature_img_loss', type=bool) + parser.add_argument('--img_recon_weight', type=float) + + # distributed training parameters + parser.add_argument('--local_rank', type=int) + + return parser + +def create_readme(config, path): + print(config.__dict__) + with open(os.path.join(path, 'README.md'), 'w+') as f: + print(config.__dict__, file=f) + +def fmri_transform(x, sparse_rate=0.2): + # x: 1, num_voxels + x_aug = copy.deepcopy(x) + idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False) + x_aug[idx] = 0 + return torch.FloatTensor(x_aug) + +def main(config): + # print('num of gpu:') + # print(torch.cuda.device_count()) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(config.local_rank) + torch.distributed.init_process_group(backend='nccl') + output_path = os.path.join(config.root_path, 'results', 'eeg_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) + config.output_path = output_path + logger = wandb_logger(config) if config.local_rank == 0 else None + # logger = None + + if config.local_rank == 0: + os.makedirs(output_path, exist_ok=True) + create_readme(config, output_path) + + device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu') + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + # create dataset and dataloader + dataset_pretrain = eeg_pretrain_dataset(path='./datasets/eeg_test/') + + print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}') + sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None + + dataloader_eeg = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, + shuffle=(sampler is None), pin_memory=True) + + # create model + config.time_len = dataset_pretrain.data_len + config.d_model = config.embed_dim + config.device = device + config.data_shape = dataset_pretrain.__getitem__(0)['eeg'].transpose(0,1).shape + print(f"Input data shape: {config.data_shape}") + + model = TimeMAE(config) + + model.to(device) + model_without_ddp = model + if torch.cuda.device_count() > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank, find_unused_parameters=config.use_nature_img_loss) + + param_groups = optim_factory.add_weight_decay(model, config.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + if logger is not None: + logger.watch_model(model,log='all', log_freq=1000) + + cor_list = [] + start_time = time.time() + print('Start Training the EEG MAE ... ...') + img_feature_extractor = None + preprocess = None + if config.use_nature_img_loss: + from torchvision.models import resnet50, ResNet50_Weights + from torchvision.models.feature_extraction import create_feature_extractor + weights = ResNet50_Weights.DEFAULT + preprocess = weights.transforms() + m = resnet50(weights=weights) + img_feature_extractor = create_feature_extractor(m, return_nodes={f'layer2': 'layer2'}).to(device).eval() + for param in img_feature_extractor.parameters(): + param.requires_grad = False + + for ep in range(config.num_epoch): + + if torch.cuda.device_count() > 1: + sampler.set_epoch(ep) # to shuffle the data at every epoch + cor = train_one_epoch(model, dataloader_eeg, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp, + img_feature_extractor, preprocess) + cor_list.append(cor) + if (ep % 20 == 0 or ep + 1 == config.num_epoch) and config.local_rank == 0: #and ep != 0 + # save models + # if True: + save_model(config, ep, model_without_ddp, optimizer, loss_scaler, os.path.join(output_path,'checkpoints')) + # plot figures + # plot_recon_figures(model, device, dataset_pretrain, output_path, 5, config, logger, model_without_ddp) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + if logger is not None: + logger.log('max cor', np.max(cor_list), step=config.num_epoch-1) + logger.finish() + return + +@torch.no_grad() +def plot_recon_figures(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None): + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + model.eval() + fig, axs = plt.subplots(num_figures, 3, figsize=(30,15)) + fig.tight_layout() + axs[0,0].set_title('Ground-truth') + axs[0,1].set_title('Masked Ground-truth') + axs[0,2].set_title('Reconstruction') + + for ax in axs: + sample = next(iter(dataloader))['eeg'] + sample = sample.to(device) + _, pred, mask = model(sample, mask_ratio=config.mask_ratio) + # sample_with_mask = model_without_ddp.patchify(sample.transpose(1,2))[0].to('cpu').numpy().reshape(-1, model_without_ddp.patch_size) + sample_with_mask = sample.to('cpu').squeeze(0)[0].numpy().reshape(-1, model_without_ddp.patch_size) + # pred = model_without_ddp.unpatchify(pred.transpose(1,2)).to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + # sample = sample.to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + pred = model_without_ddp.unpatchify(pred).to('cpu').squeeze(0)[0].numpy() + # pred = model_without_ddp.unpatchify(model_without_ddp.patchify(sample.transpose(1,2))).to('cpu').squeeze(0)[0].numpy() + sample = sample.to('cpu').squeeze(0)[0].numpy() + mask = mask.to('cpu').numpy().reshape(-1) + + cor = np.corrcoef([pred, sample])[0,1] + + x_axis = np.arange(0, sample.shape[-1]) + # groundtruth + ax[0].plot(x_axis, sample) + # groundtruth with mask + s = 0 + for x, m in zip(sample_with_mask,mask): + if m == 0: + ax[1].plot(x_axis[s:s+len(x)], x, color='#1f77b4') + s += len(x) + # pred + ax[2].plot(x_axis, pred) + ax[2].set_ylabel('cor: %.4f'%cor, weight = 'bold') + ax[2].yaxis.set_label_position("right") + + fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")) + fig.savefig(os.path.join(output_path, f'{fig_name}.png')) + if logger is not None: + logger.log_image('reconst', fig) + plt.close(fig) + + +@torch.no_grad() +def plot_recon_figures2(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None): + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + model.eval() + fig, axs = plt.subplots(num_figures, 2, figsize=(20,15)) + fig.tight_layout() + axs[0,0].set_title('Ground-truth') + # axs[0,1].set_title('Masked Ground-truth') + axs[0,1].set_title('Reconstruction') + + for ax in axs: + sample = next(iter(dataloader))['eeg'] + sample = sample.to(device) + _, pred, mask = model(sample, mask_ratio=config.mask_ratio) + # sample_with_mask = model_without_ddp.patchify(sample.transpose(1,2))[0].to('cpu').numpy().reshape(-1, model_without_ddp.patch_size) + sample_with_mask = sample.to('cpu').squeeze(0)[0].numpy().reshape(-1, model_without_ddp.patch_size) + # pred = model_without_ddp.unpatchify(pred.transpose(1,2)).to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + # sample = sample.to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + pred = model_without_ddp.unpatchify(pred).to('cpu').squeeze(0)[0].numpy() + # pred = model_without_ddp.unpatchify(model_without_ddp.patchify(sample.transpose(1,2))).to('cpu').squeeze(0)[0].numpy() + sample = sample.to('cpu').squeeze(0)[0].numpy() + cor = np.corrcoef([pred, sample])[0,1] + + x_axis = np.arange(0, sample.shape[-1]) + # groundtruth + ax[0].plot(x_axis, sample) + + ax[1].plot(x_axis, pred) + ax[1].set_ylabel('cor: %.4f'%cor, weight = 'bold') + ax[1].yaxis.set_label_position("right") + + fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")) + fig.savefig(os.path.join(output_path, f'{fig_name}.png')) + if logger is not None: + logger.log_image('reconst', fig) + plt.close(fig) + + +def update_config(args, config): + for attr in config.__dict__: + if hasattr(args, attr): + if getattr(args, attr) != None: + setattr(config, attr, getattr(args, attr)) + return config + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + config = Config_MBM_EEG() + config = update_config(args, config) + main(config) \ No newline at end of file From baf525fc6db98a438ff8688d5054edc53cc04df9 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Fri, 16 May 2025 10:42:44 +0530 Subject: [PATCH 14/16] added submodule --- .gitmodules | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .gitmodules diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..09b66ad --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "code/TimeMAE"] + path = code/TimeMAE + url = git@github.com:Mingyue-Cheng/TimeMAE.git From 8bd1e27f0ca3a43cff96790904ea2a45fdb8a78c Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Fri, 16 May 2025 15:04:07 +0530 Subject: [PATCH 15/16] simpler eegpretrain arguments were useless --- code/config.py | 1 - code/dataset.py | 6 +++--- code/stageA1_eeg_pretrain_dd.py | 5 +---- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/code/config.py b/code/config.py index 93d5cf8..45d0699 100644 --- a/code/config.py +++ b/code/config.py @@ -34,7 +34,6 @@ def __init__(self): self.root_path = './' self.output_path = './exps/' self.seed = 2022 - self.roi = 'VC' self.aug_times = 1 self.num_sub_limit = None self.include_hcp = True diff --git a/code/dataset.py b/code/dataset.py index 6ee56c4..b72665d 100644 --- a/code/dataset.py +++ b/code/dataset.py @@ -106,17 +106,17 @@ def is_npy_ext(fname: Union[str, Path]) -> bool: return f'{ext}' == 'npy'# type: ignore class eeg_pretrain_dataset(Dataset): - def __init__(self, path='./datasets/mne_data/', roi='VC', patch_size=16, transform=identity, aug_times=2, - num_sub_limit=None, include_kam=False, include_hcp=True): + def __init__(self, path): super(eeg_pretrain_dataset, self).__init__() data = [] images = [] self.input_paths = [str(f) for f in sorted(Path(path).rglob('*')) if is_npy_ext(f) and os.path.isfile(f)] + assert len(self.input_paths) != 0, 'No data found' + self.inputs = [np.load(data_path) for data_path in self.input_paths] self.real_input = np.concatenate(self.inputs, axis=0) - assert len(self.input_paths) != 0, 'No data found' self.data_len = 512 self.data_chan = 128 diff --git a/code/stageA1_eeg_pretrain_dd.py b/code/stageA1_eeg_pretrain_dd.py index ada1264..1e0fdd3 100644 --- a/code/stageA1_eeg_pretrain_dd.py +++ b/code/stageA1_eeg_pretrain_dd.py @@ -74,7 +74,6 @@ def get_args_parser(): # Project setting parser.add_argument('--root_path', type=str) parser.add_argument('--seed', type=str) - parser.add_argument('--roi', type=str) parser.add_argument('--aug_times', type=int) parser.add_argument('--num_sub_limit', type=int) @@ -121,9 +120,7 @@ def main(config): np.random.seed(config.seed) # create dataset and dataloader - dataset_pretrain = eeg_pretrain_dataset(path='./datasets/mne_data/', roi=config.roi, patch_size=config.patch_size, - transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, - include_kam=config.include_kam, include_hcp=config.include_hcp) + dataset_pretrain = eeg_pretrain_dataset(path='./datasets/eeg_train/') print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}') sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None From 019dbf09ed0ac6c60e4ce1a09060de6cf695d2d7 Mon Sep 17 00:00:00 2001 From: Ulysse Durand Date: Fri, 16 May 2025 15:22:07 +0530 Subject: [PATCH 16/16] added --- README.md | 2 +- code/config.py | 14 +++++++------- code/stageA1_eeg_pretrain_dd.py | 4 ++-- datasets/dl_moabb_datasets.py | 12 ++++++++---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 273c058..af9e02b 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ appropriate location in the ```pretrains``` directory You can download the dataset for pretraining from here [MOABB](https://github.com/NeuroTechX/moabb). The datasets used to get the papers's results aren't specified, but you can put -any EEG dataset, put them in ```datasets/mne_data/``` in a `.npy` format. +any EEG dataset, put them in ```datasets/eeg_train/``` and ```dataset/eeg_test/``` in a `.npy` format. To perform the pre-training from scratch with defaults parameters, run ```sh diff --git a/code/config.py b/code/config.py index 45d0699..f1d91b1 100644 --- a/code/config.py +++ b/code/config.py @@ -16,19 +16,19 @@ def __init__(self): self.min_lr = 0. self.weight_decay = 0.05 self.num_epoch = 500 - self.warmup_epochs = 40 - self.batch_size = 100 + self.warmup_epochs = 10 + self.batch_size = 128 self.clip_grad = 0.8 # Model Parameters - self.mask_ratio = 0.1 + self.mask_ratio = 0.5 self.patch_size = 4 # 1 self.embed_dim = 1024 #256 # has to be a multiple of num_heads self.decoder_embed_dim = 512 #128 self.depth = 24 self.num_heads = 16 self.decoder_num_heads = 16 - self.mlp_ratio = 1.0 + self.mlp_ratio = 0.8 # Project setting self.root_path = './' @@ -70,7 +70,7 @@ def __init__(self): self.lr = 5.3e-5 self.weight_decay = 0.05 self.num_epoch = 15 - self.batch_size = 16 if self.dataset == 'GOD' else 4 + self.batch_size = 32 if self.dataset == 'GOD' else 16 self.mask_ratio = 0.5 self.accum_iter = 1 self.clip_grad = 0.8 @@ -110,7 +110,7 @@ def __init__(self): np.random.seed(self.seed) # finetune parameters - self.batch_size = 5 if self.dataset == 'GOD' else 25 + self.batch_size = 32 if self.dataset == 'GOD' else 16 self.lr = 5.3e-5 self.num_epoch = 500 @@ -161,7 +161,7 @@ def __init__(self): np.random.seed(self.seed) # finetune parameters - self.batch_size = 5 if self.dataset == 'GOD' else 25 + self.batch_size = 32 if self.dataset == 'GOD' else 16 self.lr = 5.3e-5 self.num_epoch = 50 diff --git a/code/stageA1_eeg_pretrain_dd.py b/code/stageA1_eeg_pretrain_dd.py index 1e0fdd3..4913774 100644 --- a/code/stageA1_eeg_pretrain_dd.py +++ b/code/stageA1_eeg_pretrain_dd.py @@ -108,8 +108,8 @@ def main(config): torch.distributed.init_process_group(backend='nccl') output_path = os.path.join(config.root_path, 'results', 'eeg_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) config.output_path = output_path - # logger = wandb_logger(config) if config.local_rank == 0 else None - logger = None + logger = wandb_logger(config) if config.local_rank == 0 else None + # logger = None if config.local_rank == 0: os.makedirs(output_path, exist_ok=True) diff --git a/datasets/dl_moabb_datasets.py b/datasets/dl_moabb_datasets.py index c4b1016..e79c081 100644 --- a/datasets/dl_moabb_datasets.py +++ b/datasets/dl_moabb_datasets.py @@ -32,13 +32,12 @@ def find_raw_objects(data, types=(BaseRaw), path=""): datasets = [Schirrmeister2017(), GrosseWentrup2009()] sample_size = 490 -datas = [] -for dataset in datasets: +def dl_dataset(dataset, subjects, outpath): try: print("\n#####") print(dataset) print("#####\n") - data = dataset.get_data(subjects=dataset.subject_list[:2]) + data = dataset.get_data(subjects=subjects) for raw_path, raw_data in find_raw_objects(data): print() npdata = raw_data.pick_types(eeg=True).get_data() @@ -57,7 +56,7 @@ def find_raw_objects(data, types=(BaseRaw), path=""): print(f"New shape: {res.shape}") - whole_path = 'mne_data/'+dataset.code+"/"+raw_path+'.npy' + whole_path = outpath+'/'+dataset.code+"/"+raw_path+'.npy' print(f"Saving in {whole_path}") os.makedirs(os.path.dirname(whole_path), exist_ok=True) @@ -66,3 +65,8 @@ def find_raw_objects(data, types=(BaseRaw), path=""): except Exception as e: print(e) + +for dataset in datasets: + subjects = dataset.subject_list + dl_dataset(dataset, subjects[:1], "eeg_train/") + dl_dataset(dataset, subjects[1:2], "eeg_test/")