diff --git a/.gitignore b/.gitignore index f6b4615..828bafc 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ benchopt.ini .DS_Store coverage.xml + +tmp diff --git a/benchmark_utils/custom_models.py b/benchmark_utils/custom_models.py new file mode 100644 index 0000000..f38f2c6 --- /dev/null +++ b/benchmark_utils/custom_models.py @@ -0,0 +1,16 @@ +from deepinv.models import UNet + +class MRIUNet(UNet): + def __init__(self, in_channels, out_channels, scales=3, batch_norm=False): + self.name = "MRIUNet" + self.in_channels = in_channels + + super().__init__(in_channels=in_channels, out_channels=out_channels, scales=scales, batch_norm=batch_norm) + + def forward(self, x, sigma=None, **kwargs): + # Reshape for MRI specific processing + x = x.reshape(1, self.in_channels, x.shape[3], x.shape[4]) + + x = super().forward(x, sigma=sigma, **kwargs) + + return x diff --git a/benchmark_utils/denoiser_2c.py b/benchmark_utils/denoiser_2c.py new file mode 100644 index 0000000..16eae4e --- /dev/null +++ b/benchmark_utils/denoiser_2c.py @@ -0,0 +1,20 @@ + +import torch +from deepinv.models import DRUNet +from deepinv.models import Denoiser + +class Denoiser_2c(Denoiser): + def __init__(self, device): + super(Denoiser_2c, self).__init__() + self.model_c1 = DRUNet(in_channels=1, out_channels=1, pretrained="download", device=device) + self.model_c2 = DRUNet(in_channels=1, out_channels=1, pretrained="download", device=device) + + def forward(self, y, sigma): + y1, y2 = torch.split(y, 1, dim=1) + + x_hat_1 = self.model_c1(y1, sigma=sigma) + x_hat_2 = self.model_c2(y2, sigma=sigma) + + x_hat = torch.cat([x_hat_1, x_hat_2], dim=1) + + return x_hat diff --git a/benchmark_utils/fastmri_dataset.py b/benchmark_utils/fastmri_dataset.py new file mode 100644 index 0000000..a4d853c --- /dev/null +++ b/benchmark_utils/fastmri_dataset.py @@ -0,0 +1,49 @@ +import torch +from torch.utils.data import Dataset +from deepinv.datasets import FastMRISliceDataset +import torch.nn.functional as F 
+import deepinv as dinv + +class FastMRIDataset(Dataset): + def __init__(self, dataset: FastMRISliceDataset, mask, max_coils=32): + self.dataset = dataset + self.max_coils = max_coils + self.mask = mask + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + x, y = self.dataset[idx] + x, y = x.to(device=self.mask.device), y.to(device=self.mask.device) + + # Pad the width + target_width = 400 + pad_total = target_width - y.shape[3] + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + y = F.pad(y, (pad_left, pad_right, 0, 0), mode='constant', value=0) + + # Pad the height + target_height = 700 + pad_total = target_height - y.shape[2] + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + y = F.pad(y, (0, 0, pad_left, pad_right), mode='constant', value=0) + + # Transform the mask to match the kspace shape + mask = self.mask.repeat(y.shape[0], y.shape[1], 1, 1) + + # Apply the mask to the k-space data + y = y * mask + + # Add an imaginary part of zeros + x = torch.cat([x, torch.zeros_like(x)], dim=0) + + # Pad the coil dimension if necessary + coil_dim = y.shape[1] + if coil_dim < self.max_coils: + pad_size = self.max_coils - coil_dim + y = F.pad(y, (0, 0, 0, 0, 0, pad_size)) + + return x, y diff --git a/benchmark_utils/hugging_face_torch_dataset.py b/benchmark_utils/hugging_face_torch_dataset.py index 95edc10..a70bb51 100644 --- a/benchmark_utils/hugging_face_torch_dataset.py +++ b/benchmark_utils/hugging_face_torch_dataset.py @@ -2,19 +2,26 @@ class HuggingFaceTorchDataset(torch.utils.data.Dataset): - def __init__(self, hf_dataset, key, transform=None): + def __init__(self, hf_dataset, key, physics, device, transform=None): self.hf_dataset = hf_dataset self.transform = transform self.key = key + self.device = device + self.physics = physics def __len__(self): return len(self.hf_dataset) def __getitem__(self, idx): sample = self.hf_dataset[idx] - image = sample[self.key] # Image PIL + x = sample[self.key] # Image PIL 
if self.transform: - image = self.transform(image) + x = self.transform(x) + + x = x.to(self.device) - return image + y = self.physics(x.unsqueeze(0)) + y = y.squeeze(0) + + return x, y diff --git a/benchmark_utils/image_dataset.py b/benchmark_utils/image_dataset.py index 524ac74..f24d214 100644 --- a/benchmark_utils/image_dataset.py +++ b/benchmark_utils/image_dataset.py @@ -1,5 +1,7 @@ import os import random +import deepinv as dinv +import torch.nn.functional as F from torch.utils.data import Dataset from typing import Callable @@ -9,9 +11,13 @@ class ImageDataset(Dataset): def __init__(self, folder: str, + physics: dinv.physics.Physics, + device: str, transform: Callable = None, - num_images=None): + num_images=None,): self.folder = folder + self.physics = physics + self.device = device self.transform = transform self.files = [f for f in os.listdir(folder) if f.endswith(( '.png', '.jpg', '.jpeg'))] @@ -25,9 +31,27 @@ def __len__(self): def __getitem__(self, idx): img_name = os.path.join(self.folder, self.files[idx]) - image = Image.open(img_name) + x = Image.open(img_name) if self.transform: - image = self.transform(image) - - return image + x = self.transform(x) + + x = x.to(self.device) + + y = self.physics(x.unsqueeze(0)) + y = y.squeeze(0) + + #_, x_h, x_w = x.shape + #_, y_h, y_w = y.shape + + #diff_h = x_h - y_h + #diff_w = x_w - y_w + + #pad_top = diff_h // 2 + #pad_bottom = diff_h - pad_top + #pad_left = diff_w // 2 + #pad_right = diff_w - pad_left + + #y = F.pad(y, pad=(pad_left, pad_right, pad_top, pad_bottom), value=0) + + return x, y diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py new file mode 100644 index 0000000..9024aea --- /dev/null +++ b/benchmark_utils/metrics.py @@ -0,0 +1,30 @@ +import deepinv as dinv + +class CustomMSE(dinv.metric.MSE): + + transform = lambda x: x + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) + + +class 
CustomPSNR(dinv.metric.PSNR): + + transform = lambda x: x + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) + +class CustomSSIM(dinv.metric.SSIM): + + transform = lambda x: x + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) + +class CustomLPIPS(dinv.metric.LPIPS): + + transform = lambda x: x + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) diff --git a/config.yml b/config.yml index e16d5c7..179512d 100644 --- a/config.yml +++ b/config.yml @@ -1,5 +1,7 @@ -data_home: /Users/melvinenargeot/Data/benchmark_inverse_problems +data_home: /home/mind/mnargeot/benchmarks/benchmark_inverse_problems/data_tmp data_paths: generated_datasets: generated_datasets generated_trainings: generated_training - BSD500: BSD500/BSR/BSDS500/data/images \ No newline at end of file + BSD500: /data/parietal/store3/data/BSD500 + fastmri_train: /data/parietal/store3/data/fastMRI-multicoil/multicoil_train + fastmri_test: /data/parietal/store3/data/fastMRI-multicoil/multicoil_val diff --git a/datasets/bsd500_bsd20.py b/datasets/bsd500_bsd20.py index 26150bf..21b6dbe 100644 --- a/datasets/bsd500_bsd20.py +++ b/datasets/bsd500_bsd20.py @@ -5,7 +5,12 @@ import torch from torchvision import transforms from benchmark_utils.image_dataset import ImageDataset - from deepinv.physics import Downsampling, Denoising, GaussianNoise + from deepinv.physics import ( + Downsampling, + Denoising, + GaussianNoise, + Demosaicing + ) from deepinv.physics.generator import MotionBlurGenerator @@ -17,7 +22,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -29,9 +36,10 @@ def get_data(self): dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" 
n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) @@ -39,7 +47,7 @@ def get_data(self): n_channels = 3 physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filter_torch, noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device @@ -55,17 +63,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=img_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(img_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=img_size, + device=device) else: raise Exception("Unknown task") @@ -76,32 +89,36 @@ def get_data(self): train_dataset = ImageDataset( config.get_data_path("BSD500") / "train", + physics, + device, transform=transform ) test_dataset = ImageDataset( config.get_data_path("BSD500") / "val", + physics, + device, transform=transform, num_images=20 ) - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, - physics=physics, - save_dir=config.get_data_path( - key="generated_datasets" - ) / "bsd500_bsd20", - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True - ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False - ) + #dinv_dataset_path = 
dinv.datasets.generate_dataset( + # train_dataset=train_dataset, + # test_dataset=test_dataset, + # physics=physics, + # save_dir=config.get_data_path( + # key="generated_datasets" + # ) / "bsd500_bsd20", + # dataset_filename=self.task, + # device=device + #) + + #train_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, train=True + #) + #test_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, train=False + #) x, y = train_dataset[0] dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) @@ -114,5 +131,6 @@ def get_data(self): test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size ) diff --git a/datasets/bsd500_cbsd68.py b/datasets/bsd500_cbsd68.py index ccfab17..72a6944 100644 --- a/datasets/bsd500_cbsd68.py +++ b/datasets/bsd500_cbsd68.py @@ -9,7 +9,12 @@ from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Denoising, GaussianNoise, Downsampling + from deepinv.physics import ( + Denoising, + GaussianNoise, + Downsampling, + Demosaicing + ) from deepinv.physics.generator import MotionBlurGenerator @@ -21,7 +26,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -31,24 +38,25 @@ def get_data(self): # TODO: Remove device = ( dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" + + n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) noise_level_img = 0.03 - n_channels = 3 physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filter_torch, 
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device ) elif self.task == "motion-debluring": psf_size = 31 - n_channels = 3 motion_generator = MotionBlurGenerator( (psf_size, psf_size), device=device @@ -57,18 +65,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=img_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(img_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=img_size, + device=device) else: raise Exception("Unknown task") @@ -78,31 +90,38 @@ def get_data(self): ]) train_dataset = ImageDataset( - config.get_data_path("BSD500") / "train", transform=transform + config.get_data_path("BSD500") / "train", + physics, + device, + transform=transform, ) dataset_cbsd68 = load_dataset("deepinv/CBSD68") test_dataset = HuggingFaceTorchDataset( - dataset_cbsd68["train"], key="png", transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, + dataset_cbsd68["train"], + key="png", physics=physics, - save_dir=config.get_data_path( - key="generated_datasets" - ) / "bsd500_cbsd68", - dataset_filename=self.task, - device=device + device=device, + transform=transform ) - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True - ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False - ) + #dinv_dataset_path = dinv.datasets.generate_dataset( + # train_dataset=train_dataset, + # test_dataset=test_dataset, + # physics=physics, + # save_dir=config.get_data_path( + # 
key="generated_datasets" + # ) / "bsd500_cbsd68", + # dataset_filename=self.task, + # device=device + #) + + #train_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, train=True + #) + #test_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, train=False + #) x, y = train_dataset[0] dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) @@ -115,5 +134,6 @@ def get_data(self): test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size ) diff --git a/datasets/bsd500_imnet100.py b/datasets/bsd500_imnet100.py index 3b277fc..428520b 100644 --- a/datasets/bsd500_imnet100.py +++ b/datasets/bsd500_imnet100.py @@ -8,7 +8,12 @@ from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Downsampling, Denoising, GaussianNoise + from deepinv.physics import ( + Downsampling, + Denoising, + GaussianNoise, + Demosaicing + ) from deepinv.physics.generator import MotionBlurGenerator from datasets import load_dataset @@ -21,7 +26,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -31,24 +38,25 @@ def get_data(self): # TODO: Remove device = ( dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" + + n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) noise_level_img = 0.03 - n_channels = 3 physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filter_torch, noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device ) elif self.task == "motion-debluring": psf_size = 31 - 
n_channels = 3 motion_generator = MotionBlurGenerator( (psf_size, psf_size), device=device @@ -57,18 +65,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=img_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(img_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=img_size, + device=device) else: raise Exception("Unknown task") @@ -79,6 +91,8 @@ def get_data(self): train_dataset = ImageDataset( config.get_data_path("BSD500") / "train", + physics=physics, + device=device, transform=transform ) @@ -86,26 +100,28 @@ def get_data(self): test_dataset = HuggingFaceTorchDataset( dataset_miniImnet100["validation"], key="image", - transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, physics=physics, - save_dir=config.get_data_path( - key="generated_datasets" - ) / "bsd500_imnet100", - dataset_filename=self.task, - device=device + device=device, + transform=transform ) - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True - ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False - ) + #dinv_dataset_path = dinv.datasets.generate_dataset( + # train_dataset=train_dataset, + # test_dataset=test_dataset, + # physics=physics, + # save_dir=config.get_data_path( + # key="generated_datasets" + # ) / "bsd500_imnet100", + # dataset_filename=self.task, + # device=device + #) + + #train_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, train=True + #) + #test_dataset = 
dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, train=False + #) x, y = train_dataset[0] dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) @@ -118,5 +134,6 @@ def get_data(self): test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size ) diff --git a/datasets/cbsd68_set3c.py b/datasets/cbsd68_set3c.py index 325c160..ca9101e 100644 --- a/datasets/cbsd68_set3c.py +++ b/datasets/cbsd68_set3c.py @@ -8,7 +8,12 @@ from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Denoising, GaussianNoise, Downsampling + from deepinv.physics import ( + Denoising, + GaussianNoise, + Downsampling, + Demosaicing, + ) from deepinv.physics.generator import MotionBlurGenerator @@ -20,7 +25,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -32,23 +39,24 @@ def get_data(self): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) + n_channels = 3 + image_size = (n_channels, self.img_size, self.img_size) + if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) noise_level_img = 0.03 - n_channels = 3 physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=image_size, filter=filter_torch, noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device ) elif self.task == "motion-debluring": psf_size = 31 - n_channels = 3 motion_generator = MotionBlurGenerator( (psf_size, psf_size), device=device @@ -57,18 +65,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + 
img_size=image_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=image_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(image_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=image_size, + device=device) else: raise Exception("Unknown task") @@ -79,33 +91,41 @@ def get_data(self): dataset_CBSD68 = load_dataset("deepinv/CBSD68") train_dataset = HuggingFaceTorchDataset( - dataset_CBSD68["train"], key="png", transform=transform + dataset_CBSD68["train"], + key="png", + physics=physics, + device=device, + transform=transform ) dataset_Set3c = load_dataset("deepinv/set3c") test_dataset = HuggingFaceTorchDataset( - dataset_Set3c["train"], key="image", transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, + dataset_Set3c["train"], + key="image", physics=physics, - save_dir=config.get_data_path( - key="generated_datasets" - ) / "sbsd68_set3c", - dataset_filename=self.task, - device=device + device=device, + transform=transform ) - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, - train=True - ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, - train=False - ) + #dinv_dataset_path = dinv.datasets.generate_dataset( + # train_dataset=train_dataset, + # test_dataset=test_dataset, + # physics=physics, + # save_dir=config.get_data_path( + # key="generated_datasets" + # ) / "sbsd68_set3c", + # dataset_filename=self.task, + # device=device + #) + + #train_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, + # train=True + #) + #test_dataset = dinv.datasets.HDF5Dataset( + # path=dinv_dataset_path, + # train=False + #) x, y = train_dataset[0] 
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) @@ -118,5 +138,6 @@ def get_data(self): test_dataset=test_dataset, physics=physics, dataset_name="Set3c", - task_name=self.task + task_name=self.task, + image_size=image_size ) diff --git a/datasets/fastmri.py b/datasets/fastmri.py new file mode 100644 index 0000000..56cb900 --- /dev/null +++ b/datasets/fastmri.py @@ -0,0 +1,55 @@ +from benchopt import BaseDataset, safe_import_context, config + +with safe_import_context() as import_ctx: + import deepinv as dinv + import torch, torchvision + from torch.utils.data import DataLoader + from benchmark_utils.fastmri_dataset import FastMRIDataset + +MAX_COILS = 32 # Maximum number of coils to pad to +KSPACE_PADDED_SIZE = (700, 400) # K-space size for FastMRI dataset + +class Dataset(BaseDataset): + name = "FastMRI" + + parameters = {} + + def get_data(self): + device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" + rng = torch.Generator(device=device).manual_seed(0) + + physics_generator = dinv.physics.generator.GaussianMaskGenerator( + img_size=KSPACE_PADDED_SIZE, acceleration=4, rng=rng, device=device + ) + mask = physics_generator.step( + batch_size=1, img_size=KSPACE_PADDED_SIZE + )["mask"] + + train_dataset = FastMRIDataset(dinv.datasets.FastMRISliceDataset( + config.get_data_path(key="fastmri_train"), slice_index="middle" + ), mask, MAX_COILS) + + test_dataset = FastMRIDataset(dinv.datasets.FastMRISliceDataset( + config.get_data_path(key="fastmri_test"), slice_index="middle" + ), mask, MAX_COILS) + + + x, y = train_dataset[0] + + img_size, kspace_shape = x.shape[-2:], KSPACE_PADDED_SIZE + + physics = dinv.physics.MultiCoilMRI( + img_size=img_size, + mask=mask, + coil_maps=torch.ones((MAX_COILS,) + kspace_shape, dtype=torch.complex64), + device=device, + ) + + return dict( + train_dataset=train_dataset, + test_dataset=test_dataset, + physics=physics, + dataset_name="FastMRI", + task_name="MRI", + image_size=y.shape + ) diff --git 
a/datasets/simulated.py b/datasets/simulated.py index 27befd3..c4c7fb5 100644 --- a/datasets/simulated.py +++ b/datasets/simulated.py @@ -25,5 +25,6 @@ def get_data(self): test_dataset=test_dataset, physics=Denoising(GaussianNoise(sigma=0.03)), dataset_name="simulated", - task_name="test" + task_name="test", + image_size=(3, 32, 32) ) diff --git a/objective.py b/objective.py index 5ed8c3b..00df7f2 100644 --- a/objective.py +++ b/objective.py @@ -7,6 +7,11 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv + import torchvision + import torch.nn.functional as F + from benchmark_utils.metrics import CustomPSNR, CustomSSIM, CustomLPIPS + from tqdm import tqdm + import time # The benchmark objective must be named `Objective` and @@ -43,7 +48,8 @@ def set_data(self, test_dataset, physics, dataset_name, - task_name): + task_name, + image_size): # The keyword arguments of this function are the keys of the dictionary # returned by `Dataset.get_data`. This defines the benchmark's # API to pass data. This is customizable for each benchmark. @@ -52,6 +58,7 @@ def set_data(self, self.physics = physics self.dataset_name = dataset_name self.task_name = task_name + self.image_size = image_size def evaluate_result(self, model, model_name, device): # The keyword arguments of this function are the keys of the @@ -59,49 +66,98 @@ def evaluate_result(self, model, model_name, device): # benchmark's API to pass solvers' result. This is customizable for # each benchmark. 
- batch_size = 2 + batch_size = 1 test_dataloader = DataLoader( self.test_dataset, batch_size=batch_size, shuffle=False ) - if isinstance(model, dinv.models.DeepImagePrior): - psnr = [] - ssim = [] - lpips = [] - - for x, y in test_dataloader: - x, y = x.to(device), y.to(device) - x_hat = torch.cat([ + # DeepImagePrior use images one by one, thus we can't use dinv.test + #if isinstance(model, dinv.models.DeepImagePrior): + psnr = [] + ssim = [] + lpips = [] + times = [] + + for x, y in tqdm(test_dataloader, desc=f"Evaluating {model_name}"): + x, y = x.to(device), y.to(device) + + if isinstance(model, dinv.models.DeepImagePrior): + start = time.time() + x_hat = [ model(y_i[None], self.physics) for y_i in y - ]) + ] + exec_time = time.time() - start + x_hat = torch.cat(x_hat) + else: + if type(self.physics) is dinv.physics.blur.Downsampling and model_name == 'U-Net': + _, _, x_h, x_w = x.shape + _, _, y_h, y_w = y.shape + + diff_h = x_h - y_h + diff_w = x_w - y_w + + pad_top = diff_h // 2 + pad_bottom = diff_h - pad_top + pad_left = diff_w // 2 + pad_right = diff_w - pad_left + + y = F.pad(y, pad=(pad_left, pad_right, pad_top, pad_bottom), value=0) + + start = time.time() + x_hat = model(y, self.physics) + exec_time = time.time() - start + + times.append(exec_time) + + if (self.dataset_name == 'FastMRI'): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + dinv.metric.functional.complex_abs, + ] + ) + + CustomPSNR.transform = transform + + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + ] + ) + + CustomSSIM.transform = transform + + psnr.append(CustomPSNR()(x_hat, x)) + else: psnr.append(dinv.metric.PSNR()(x_hat, x)) ssim.append(dinv.metric.SSIM()(x_hat, x)) lpips.append(dinv.metric.LPIPS(device=device)(x_hat, x)) - psnr = torch.mean(torch.cat(psnr)).item() + psnr = torch.mean(torch.cat(psnr)).item() + times = torch.mean(torch.tensor(times)).item() + + results = 
dict(PSNR=psnr) + + if self.dataset_name != 'FastMRI': ssim = torch.mean(torch.cat(ssim)).item() lpips = torch.mean(torch.cat(lpips)).item() + results['SSIM'] = ssim + results['LPIPS'] = lpips + + results['Time'] = times - results = dict(PSNR=psnr, SSIM=ssim, LPIPS=lpips) - else: - results = dinv.test( - model, - test_dataloader, - self.physics, - metrics=[dinv.metric.PSNR(), - dinv.metric.SSIM(), - dinv.metric.LPIPS(device=device)], - device=device - ) - - # This method can return many metrics in a dictionary. One of these - # metrics needs to be `value` for convergence detection purposes. - return dict( + values = dict( value=results["PSNR"], - ssim=results["SSIM"], - lpips=results["LPIPS"] ) + if self.dataset_name != 'FastMRI': + values['ssim'] = results["SSIM"] + values['lpips'] = results["LPIPS"] + + values['time'] = results["Time"] + + return values + def get_one_result(self): # Return one solution. The return value should be an object compatible # with `self.evaluate_result`. This is mainly for testing purposes. @@ -114,4 +170,8 @@ def get_objective(self): # for `Solver.set_objective`. This defines the # benchmark's API for passing the objective to the solver. # It is customizable for each benchmark. 
-        return dict(train_dataset=self.train_dataset, physics=self.physics)
+
+        return dict(train_dataset=self.train_dataset,
+                    physics=self.physics,
+                    image_size=self.image_size,
+                    dataset_name=self.dataset_name,)
diff --git a/solvers/ddrm.py b/solvers/ddrm.py
new file mode 100644
index 0000000..dc51cec
--- /dev/null
+++ b/solvers/ddrm.py
@@ -0,0 +1,48 @@
+from benchopt import BaseSolver, safe_import_context
+
+with safe_import_context() as import_ctx:
+    import torch
+    from torch.utils.data import DataLoader
+    import deepinv as dinv
+    import numpy as np
+
+
+class Solver(BaseSolver):
+    name = 'DDRM'
+
+    parameters = {}
+
+    sampling_strategy = 'run_once'
+
+    requirements = []
+
+    def set_objective(self, train_dataset, physics, image_size):
+        batch_size = 2
+        self.train_dataloader = DataLoader(
+            train_dataset, batch_size=batch_size, shuffle=False
+        )
+        self.device = (
+            dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
+        )
+        self.physics = physics
+
+    def run(self, n_iter):
+        denoiser = dinv.models.DRUNet(pretrained="download").to(self.device)
+
+        sigmas = (np.linspace(1,0, 100)
+                  if torch.cuda.is_available()
+                  else np.linspace(1, 0, 10))
+
+        self.model = dinv.sampling.DDRM(
+            denoiser=denoiser,
+            etab=1.0,
+            sigmas=sigmas,
+            verbose=True
+        )
+        self.model.eval()
+
+    def get_result(self):
+        return dict(model=self.model, model_name="DDRM", device=self.device)
+
+    def skip(self, train_dataset, physics, image_size, dataset_name):
+        return True, "Not yet implemented."
diff --git a/solvers/diffpir.py b/solvers/diffpir.py index a3bcaf2..ff7f7e4 100644 --- a/solvers/diffpir.py +++ b/solvers/diffpir.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv + from benchmark_utils.denoiser_2c import Denoiser_2c class Solver(BaseSolver): @@ -15,7 +16,7 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): + def set_objective(self, train_dataset, physics, image_size, dataset_name): batch_size = 2 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False @@ -25,14 +26,20 @@ def set_objective(self, train_dataset, physics): ) self.physics = physics + self.image_size = image_size + def run(self, n_iter): - denoiser = dinv.models.DRUNet(pretrained="download").to(self.device) + if self.image_size[0] == 2: + denoiser = Denoiser_2c(device=self.device) + else: + denoiser = dinv.models.DRUNet(pretrained="download").to(self.device) self.model = dinv.sampling.DiffPIR( model=denoiser, data_fidelity=dinv.optim.data_fidelity.L2(), device=self.device ) + self.model.eval() def get_result(self): diff --git a/solvers/dip.py b/solvers/dip.py index 3a41c91..7566887 100644 --- a/solvers/dip.py +++ b/solvers/dip.py @@ -16,9 +16,9 @@ class Solver(BaseSolver): requirements = ["optuna"] - def set_objective(self, train_dataset, physics): + def set_objective(self, train_dataset, physics, image_size, dataset_name): self.train_dataset = train_dataset - batch_size = 32 + batch_size = 1 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) @@ -26,24 +26,24 @@ def set_objective(self, train_dataset, physics): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) self.physics = physics.to(self.device) + self.image_size = image_size def run(self, n_iter): def objective(trial): lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True) iterations = trial.suggest_int('iterations', 50, 500, log=True) - # TODO: Remove - # 
iterations = 5 - model = self.get_model(lr, iterations) psnr = [] for x, y in self.train_dataloader: x, y = x.to(self.device), y.to(self.device) + x_hat = torch.cat([ model(y_i[None], self.physics) for y_i in y ]) + psnr.append(dinv.metric.PSNR()(x_hat, x)) psnr = torch.mean(torch.cat(psnr)).item() @@ -51,13 +51,11 @@ def objective(trial): return psnr study = optuna.create_study(direction='maximize') - study.optimize(objective, n_trials=1) + study.optimize(objective, n_trials=3) best_trial = study.best_trial best_params = best_trial.params - # TODO : replace 5 by best_params['iterations']) - # self.model = self.get_model(best_params['lr'], 5) self.model = self.get_model( best_params['lr'], best_params['iterations'] diff --git a/solvers/dpir.py b/solvers/dpir.py index 5669eed..19f350b 100644 --- a/solvers/dpir.py +++ b/solvers/dpir.py @@ -5,7 +5,15 @@ from torch.utils.data import DataLoader import deepinv as dinv import numpy as np - + import torchvision + from deepinv.optim import BaseOptim + from deepinv.optim.prior import PnP + from deepinv.optim.data_fidelity import L2 + from deepinv.optim.optimizers import create_iterator + from deepinv.optim.dpir import get_DPIR_params + from benchmark_utils.denoiser_2c import Denoiser_2c + from benchmark_utils.metrics import CustomPSNR + from tqdm import tqdm class Solver(BaseSolver): name = 'DPIR' @@ -16,8 +24,8 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 1 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) @@ -25,27 +33,69 @@ def set_objective(self, train_dataset, physics): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) self.physics = physics + self.image_size = image_size + self.dataset_name = dataset_name def run(self, n_iter): best_sigma = 0 best_psnr = 0 + + # If the number of channels is 2 we use a 
custom DPIR solver + if self.image_size[0] == 2: + model_class = DPIR_2C + else: + model_class = dinv.optim.DPIR + + # If the number of channels is different from 1 or 3 + # then we can't use pretrained DRUNet for sigma in np.linspace(0.01, 0.1, 10): - model = dinv.optim.DPIR(sigma=sigma, device=self.device) + model = model_class(sigma=sigma, device=self.device) + + psnr = [] + + for x, y in tqdm(self.train_dataloader, desc=f"DPIR : Looking for the best sigma"): + x, y = x.to(self.device), y.to(self.device) + + x_hat = model(y, self.physics) + + if (self.dataset_name == 'FastMRI'): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + dinv.metric.functional.complex_abs, + ] + ) - results = dinv.test( - model, - self.train_dataloader, - self.physics, - metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()], - device=self.device - ) + CustomPSNR.transform = transform - if results["PSNR"] > best_psnr: + psnr.append(CustomPSNR()(x_hat, x)) + else: + psnr.append(dinv.metric.PSNR()(x_hat, x)) + + psnr = torch.mean(torch.cat(psnr)).item() + + if psnr > best_psnr: best_sigma = sigma - best_psnr = results["PSNR"] + best_psnr = psnr - self.model = dinv.optim.DPIR(sigma=best_sigma, device=self.device) + self.model = model_class(sigma=best_sigma, device=self.device) self.model.eval() def get_result(self): return dict(model=self.model, model_name="DPIR", device=self.device) + + +# Custom DPIR solver with 2 channels +class DPIR_2C(BaseOptim): + def __init__(self, sigma=0.1, device="cuda"): + prior = PnP(denoiser=Denoiser_2c(device=device)) + sigma_denoiser, stepsize, max_iter = get_DPIR_params(sigma) + params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser} + super(DPIR_2C, self).__init__( + create_iterator("HQS", prior=prior, F_fn=None, g_first=False), + max_iter=max_iter, + data_fidelity=L2(), + prior=prior, + early_stop=False, + params_algo=params_algo, + ) diff --git a/solvers/ifft2.py b/solvers/ifft2.py new file mode 100644 
index 0000000..70205b9 --- /dev/null +++ b/solvers/ifft2.py @@ -0,0 +1,44 @@ +from benchopt import BaseSolver, safe_import_context + +with safe_import_context() as import_ctx: + import torch + from torch.utils.data import DataLoader + import deepinv as dinv + import numpy as np + + +class Solver(BaseSolver): + name = 'IFFT2' + + parameters = {} + + sampling_strategy = 'run_once' + + requirements = [] + + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 2 + self.train_dataloader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=False + ) + self.device = ( + dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" + ) + self.physics = physics + self.image_size = image_size + self.dataset_name = dataset_name + + def run(self, n_iter): + def model(y, physics): + return physics.A_dagger(y) + + self.model = model + + def get_result(self): + return dict(model=self.model, model_name="IFFT2", device=self.device) + + def skip(self, **objective_dict): + if isinstance(objective_dict['physics'], dinv.physics.mri.MultiCoilMRI): + return False, None + + return True, "This solver is only available for MRI dataset" diff --git a/solvers/u-net.py b/solvers/u-net.py index b3e9534..ba54978 100644 --- a/solvers/u-net.py +++ b/solvers/u-net.py @@ -3,8 +3,12 @@ with safe_import_context() as import_ctx: import torch + import torch.nn.functional as F from torch.utils.data import DataLoader import deepinv as dinv + import torchvision + from benchmark_utils.metrics import CustomMSE, CustomPSNR + from benchmark_utils.custom_models import MRIUNet class Solver(BaseSolver): @@ -19,8 +23,8 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 1 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) @@ -28,43 +32,90 @@ def set_objective(self, 
train_dataset, physics): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) self.physics = physics.to(self.device) + self.image_size = image_size + self.dataset_name = dataset_name def run(self, n_iter): epochs = 4 + + x, y = next(iter(self.train_dataloader)) - model = dinv.models.UNet( - in_channels=3, out_channels=3, scales=3, batch_norm=False - ).to(self.device) + if self.dataset_name == 'FastMRI': + model = MRIUNet( + in_channels=y.shape[1] * y.shape[2], out_channels=x.shape[1], scales=3, + batch_norm=False + ).to(self.device) + else: + model = dinv.models.UNet( + in_channels=y.shape[1], out_channels=x.shape[1], scales=4, + batch_norm=False + ).to(self.device) - verbose = True # print training information - wandb_vis = False # plot curves and images in Weight&Bias - - # choose training losses - losses = dinv.loss.SupLoss(metric=dinv.metric.MSE()) - - # choose optimizer and scheduler optimizer = torch.optim.Adam( model.parameters(), lr=self.lr, weight_decay=1e-8 ) scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, step_size=int(epochs * 0.8) - ) - trainer = dinv.Trainer( - model, - device=self.device, - verbose=verbose, - wandb_vis=wandb_vis, - physics=self.physics, - epochs=epochs, - scheduler=scheduler, - losses=losses, - optimizer=optimizer, - show_progress_bar=True, - train_dataloader=self.train_dataloader, + optimizer, step_size=int(epochs * 0.7) ) - self.model = trainer.train() - self.model.eval() + # choose training losses + if self.dataset_name == 'FastMRI': + criterion = dinv.loss.SupLoss(metric=CustomMSE()) + else: + criterion = dinv.loss.SupLoss(metric=dinv.metric.MSE()) + + for epoch in range(epochs): + model.train() + running_loss = 0.0 + + for x, y in self.train_dataloader: + x, y = x.to(self.device), y.to(self.device) + + if type(self.physics) is dinv.physics.blur.Downsampling: + _, _, x_h, x_w = x.shape + _, _, y_h, y_w = y.shape + + diff_h = x_h - y_h + diff_w = x_w - y_w + + pad_top = diff_h // 2 + pad_bottom = diff_h - 
pad_top + pad_left = diff_w // 2 + pad_right = diff_w - pad_left + + y = F.pad(y, pad=(pad_left, pad_right, pad_top, pad_bottom), value=0) + + x_hat = model(y, self.physics) + + if self.dataset_name == 'FastMRI': + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + dinv.metric.functional.complex_abs, + ] + ) + criterion.metric.transform = transform + + #if type(self.physics) is dinv.physics.blur.Downsampling: + #breakpoint() + + loss = criterion(x_hat, x) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + + avg_loss = running_loss / len(self.train_dataloader) + print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}") + + scheduler.step() + + model.eval() + + self.model = model def get_result(self): - return dict(model=self.model, model_name="U-Net", device=self.device) + return dict(model=self.model, model_name=f"U-Net_{self.lr}", device=self.device) +