Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
12cd68b
WIP add FastMRI dataset and demosaicing task
Melvin-klein May 19, 2025
6a8108f
WIP
Melvin-klein May 26, 2025
80f232b
WIP Add ifft2 as baseline
Melvin-klein May 26, 2025
217436c
WIP skip ifft2 solver on datasets that are not FastMRI
Melvin-klein May 26, 2025
5b7602f
WIP
Melvin-klein Jun 5, 2025
65beb1a
WIP
Melvin-klein Jun 5, 2025
4525802
WIP : Add new denoiser for DPIR
Melvin-klein Jun 11, 2025
92e8791
Removed unused files
Melvin-klein Jun 11, 2025
a0227e5
Add DPIR_2C to handle imaginary images
Melvin-klein Jun 12, 2025
b8fa8ee
Refactor code
Melvin-klein Jun 12, 2025
2a1a435
Update DiffPIR for imaginary images
Melvin-klein Jun 12, 2025
5c1ea37
Every solvers run on SimpleFastMIR
Melvin-klein Jun 12, 2025
7442ff1
WIP: Change SimpleFastMRISliceDataset to FastMRISliceDataset
Melvin-klein Jun 16, 2025
c508509
WIP
Jun 18, 2025
27a3f07
WIP
Jun 18, 2025
4f9e189
WIP
Jul 17, 2025
0737ca6
WIP
Jul 18, 2025
386be6a
WIP
Melvin-klein Jul 18, 2025
ecc81ae
WIP
Melvin-klein Jul 19, 2025
8029157
WIP
Melvin-klein Jul 22, 2025
25b1375
UNet working
Melvin-klein Jul 23, 2025
ac88628
UNet, DPIR, DiffPIR working
Melvin-klein Jul 23, 2025
7a55efd
WIP
Melvin-klein Jul 23, 2025
be94a63
WIP
Melvin-klein Jul 29, 2025
23baf2b
WIP
Melvin-klein Aug 4, 2025
4af546b
Added inpainting, fix bugs
Melvin-klein Aug 25, 2025
65c1f65
Add inference time per degraded image to metrics
Melvin-klein Aug 27, 2025
32526b3
Fix U-Net solver's scheduler
Melvin-klein Aug 27, 2025
8ca02b8
Fix comments
Melvin-klein Sep 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ benchopt.ini

.DS_Store
coverage.xml

tmp
16 changes: 16 additions & 0 deletions benchmark_utils/custom_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from deepinv.models import UNet

class MRIUNet(UNet):
def __init__(self, in_channels, out_channels, scales=3, batch_norm=False):
self.name = "MRIUNet"
self.in_channels = in_channels

super().__init__(in_channels=in_channels, out_channels=out_channels, scales=scales, batch_norm=batch_norm)

def forward(self, x, sigma=None, **kwargs):
# Reshape for MRI specific processing
x = x.reshape(1, self.in_channels, x.shape[3], x.shape[4])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it forcing the batch_size to be 1? why not x.shape[0]? Why do you need to reshape at all actually?


x = super().forward(x, sigma=sigma, **kwargs)

return x
20 changes: 20 additions & 0 deletions benchmark_utils/denoiser_2c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

import torch
from deepinv.models import DRUNet
from deepinv.models import Denoiser

class Denoiser_2c(Denoiser):
def __init__(self, device):
super(Denoiser_2c, self).__init__()
self.model_c1 = DRUNet(in_channels=1, out_channels=1, pretrained="download", device=device)
self.model_c2 = DRUNet(in_channels=1, out_channels=1, pretrained="download", device=device)

def forward(self, y, sigma):
y1, y2 = torch.split(y, 1, dim=1)

x_hat_1 = self.model_c1(y1, sigma=sigma)
x_hat_2 = self.model_c2(y2, sigma=sigma)

x_hat = torch.cat([x_hat_1, x_hat_2], dim=1)

return x_hat
49 changes: 49 additions & 0 deletions benchmark_utils/fastmri_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from torch.utils.data import Dataset
from deepinv.datasets import FastMRISliceDataset
import torch.nn.functional as F
import deepinv as dinv

class FastMRIDataset(Dataset):
def __init__(self, dataset: FastMRISliceDataset, mask, max_coils=32):
self.dataset = dataset
self.max_coils = max_coils
self.mask = mask

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
x, y = self.dataset[idx]
x, y = x.to(device=self.mask.device), y.to(device=self.mask.device)

# Pad the width
target_width = 400
pad_total = target_width - y.shape[3]
pad_left = pad_total // 2
pad_right = pad_total - pad_left
y = F.pad(y, (pad_left, pad_right, 0, 0), mode='constant', value=0)

# Pad the height
target_height = 700
pad_total = target_height - y.shape[2]
pad_left = pad_total // 2
pad_right = pad_total - pad_left
y = F.pad(y, (0, 0, pad_left, pad_right), mode='constant', value=0)

# Transform the mask to match the kspace shape
mask = self.mask.repeat(y.shape[0], y.shape[1], 1, 1)

# Apply the mask to the k-space data
y = y * mask

# Add an imaginary part of zeros
x = torch.cat([x, torch.zeros_like(x)], dim=0)

# Pad the coil dimension if necessary
coil_dim = y.shape[1]
if coil_dim < self.max_coils:
pad_size = self.max_coils - coil_dim
y = F.pad(y, (0, 0, 0, 0, 0, pad_size))

return x, y
15 changes: 11 additions & 4 deletions benchmark_utils/hugging_face_torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@


class HuggingFaceTorchDataset(torch.utils.data.Dataset):
def __init__(self, hf_dataset, key, transform=None):
def __init__(self, hf_dataset, key, physics, device, transform=None):
self.hf_dataset = hf_dataset
self.transform = transform
self.key = key
self.device = device
self.physics = physics

def __len__(self):
return len(self.hf_dataset)

def __getitem__(self, idx):
sample = self.hf_dataset[idx]
image = sample[self.key] # Image PIL
x = sample[self.key] # Image PIL

if self.transform:
image = self.transform(image)
x = self.transform(x)

x = x.to(self.device)

return image
y = self.physics(x.unsqueeze(0))
y = y.squeeze(0)

return x, y
34 changes: 29 additions & 5 deletions benchmark_utils/image_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import random
import deepinv as dinv
import torch.nn.functional as F

from torch.utils.data import Dataset
from typing import Callable
Expand All @@ -9,9 +11,13 @@
class ImageDataset(Dataset):
def __init__(self,
folder: str,
physics: dinv.physics.Physics,
device: str,
transform: Callable = None,
num_images=None):
num_images=None,):
self.folder = folder
self.physics = physics
self.device = device
self.transform = transform
self.files = [f for f in os.listdir(folder) if f.endswith((
'.png', '.jpg', '.jpeg'))]
Expand All @@ -25,9 +31,27 @@ def __len__(self):

def __getitem__(self, idx):
img_name = os.path.join(self.folder, self.files[idx])
image = Image.open(img_name)
x = Image.open(img_name)

if self.transform:
image = self.transform(image)

return image
x = self.transform(x)

x = x.to(self.device)

y = self.physics(x.unsqueeze(0))
y = y.squeeze(0)

#_, x_h, x_w = x.shape
#_, y_h, y_w = y.shape

#diff_h = x_h - y_h
#diff_w = x_w - y_w

#pad_top = diff_h // 2
#pad_bottom = diff_h - pad_top
#pad_left = diff_w // 2
#pad_right = diff_w - pad_left

#y = F.pad(y, pad=(pad_left, pad_right, pad_top, pad_bottom), value=0)

return x, y
30 changes: 30 additions & 0 deletions benchmark_utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import deepinv as dinv

class CustomMSE(dinv.metric.MSE):

transform = lambda x: x

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)


class CustomPSNR(dinv.metric.PSNR):

transform = lambda x: x

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)

class CustomSSIM(dinv.metric.SSIM):

transform = lambda x: x

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)

class CustomLPIPS(dinv.metric.LPIPS):

transform = lambda x: x

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)
6 changes: 4 additions & 2 deletions config.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
data_home: /Users/melvinenargeot/Data/benchmark_inverse_problems
data_home: /home/mind/mnargeot/benchmarks/benchmark_inverse_problems/data_tmp
data_paths:
generated_datasets: generated_datasets
generated_trainings: generated_training
BSD500: BSD500/BSR/BSDS500/data/images
BSD500: /data/parietal/store3/data/BSD500
fastmri_train: /data/parietal/store3/data/fastMRI-multicoil/multicoil_train
fastmri_test: /data/parietal/store3/data/fastMRI-multicoil/multicoil_val
70 changes: 44 additions & 26 deletions datasets/bsd500_bsd20.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import torch
from torchvision import transforms
from benchmark_utils.image_dataset import ImageDataset
from deepinv.physics import Downsampling, Denoising, GaussianNoise
from deepinv.physics import (
Downsampling,
Denoising,
GaussianNoise,
Demosaicing
)
from deepinv.physics.generator import MotionBlurGenerator


Expand All @@ -17,7 +22,9 @@ class Dataset(BaseDataset):
'task': ['denoising',
'gaussian-debluring',
'motion-debluring',
'SRx4'],
'SRx4',
'inpainting',
'demosaicing'],
'img_size': [256],
}

Expand All @@ -29,17 +36,18 @@ def get_data(self):
dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu"

n_channels = 3
img_size = (n_channels, self.img_size, self.img_size)

if self.task == "denoising":
noise_level_img = 0.03
noise_level_img = 0.1
physics = Denoising(GaussianNoise(sigma=noise_level_img))
elif self.task == "gaussian-debluring":
filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
noise_level_img = 0.03
n_channels = 3

physics = dinv.physics.BlurFFT(
img_size=(n_channels, self.img_size, self.img_size),
img_size=img_size,
filter=filter_torch,
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
device=device
Expand All @@ -55,17 +63,22 @@ def get_data(self):
filters = motion_generator.step(batch_size=1)

physics = dinv.physics.BlurFFT(
img_size=(n_channels, self.img_size, self.img_size),
img_size=img_size,
filter=filters["filter"],
device=device
)
elif self.task == "SRx4":
physics = Downsampling(img_size=(n_channels,
self.img_size,
self.img_size),
physics = Downsampling(img_size=img_size,
filter="bicubic",
factor=4,
device=device)
elif self.task == "inpainting":
physics = dinv.physics.Inpainting(img_size,
mask=0.7,
device=device)
elif self.task == "demosaicing":
physics = Demosaicing(img_size=img_size,
device=device)
else:
raise Exception("Unknown task")

Expand All @@ -76,32 +89,36 @@ def get_data(self):

train_dataset = ImageDataset(
config.get_data_path("BSD500") / "train",
physics,
device,
transform=transform
)

test_dataset = ImageDataset(
config.get_data_path("BSD500") / "val",
physics,
device,
transform=transform,
num_images=20
)

dinv_dataset_path = dinv.datasets.generate_dataset(
train_dataset=train_dataset,
test_dataset=test_dataset,
physics=physics,
save_dir=config.get_data_path(
key="generated_datasets"
) / "bsd500_bsd20",
dataset_filename=self.task,
device=device
)

train_dataset = dinv.datasets.HDF5Dataset(
path=dinv_dataset_path, train=True
)
test_dataset = dinv.datasets.HDF5Dataset(
path=dinv_dataset_path, train=False
)
#dinv_dataset_path = dinv.datasets.generate_dataset(
# train_dataset=train_dataset,
# test_dataset=test_dataset,
# physics=physics,
# save_dir=config.get_data_path(
# key="generated_datasets"
# ) / "bsd500_bsd20",
# dataset_filename=self.task,
# device=device
#)

#train_dataset = dinv.datasets.HDF5Dataset(
# path=dinv_dataset_path, train=True
#)
#test_dataset = dinv.datasets.HDF5Dataset(
# path=dinv_dataset_path, train=False
#)

x, y = train_dataset[0]
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])
Expand All @@ -114,5 +131,6 @@ def get_data(self):
test_dataset=test_dataset,
physics=physics,
dataset_name="BSD68",
task_name=self.task
task_name=self.task,
image_size=img_size
)
Loading
Loading