From 036ebe5d3716c97402b6458b5ffb072c323d1209 Mon Sep 17 00:00:00 2001 From: Lukas Nepelius Date: Wed, 9 Apr 2025 11:13:42 +0200 Subject: [PATCH 1/6] fix: add check for image dimensions during loading closes #14 --- src/viqa/utils/loading.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/viqa/utils/loading.py b/src/viqa/utils/loading.py index 7039693..c4ba6da 100644 --- a/src/viqa/utils/loading.py +++ b/src/viqa/utils/loading.py @@ -738,6 +738,14 @@ def load_data( "Input type not supported" ) # Raise exception if input type is not supported + if batch: + for img_num, img in enumerate(img_arr): + if len(img.shape) > 3 or len(img.shape) < 2: + raise ValueError(f"Number of dimensions of input image should be 2 or 3, got {len(img.shape)} for image number {img_num} in img_arr.") + elif not isinstance(img_arr, list): + if len(img_arr.shape) > 3 or len(img_arr.shape) < 2: + raise ValueError(f"Number of dimensions of input image should be 2 or 3, got {len(img_arr.shape)}.") + # Normalize data if normalize and data_range: # TODO: Deprecate in version 4.0.0 From c573452afeff69cfc67b46cc073b6acfe5a20ab7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 09:19:53 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/viqa/utils/loading.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/viqa/utils/loading.py b/src/viqa/utils/loading.py index c4ba6da..338fd48 100644 --- a/src/viqa/utils/loading.py +++ b/src/viqa/utils/loading.py @@ -741,10 +741,14 @@ def load_data( if batch: for img_num, img in enumerate(img_arr): if len(img.shape) > 3 or len(img.shape) < 2: - raise ValueError(f"Number of dimensions of input image should be 2 or 3, got {len(img.shape)} for image number {img_num} in img_arr.") - elif not isinstance(img_arr, list): + raise ValueError( + f"Number of dimensions 
of input image should be 2 or 3, got {len(img.shape)} for image number {img_num} in img_arr." + ) + elif not isinstance(img_arr, list): if len(img_arr.shape) > 3 or len(img_arr.shape) < 2: - raise ValueError(f"Number of dimensions of input image should be 2 or 3, got {len(img_arr.shape)}.") + raise ValueError( + f"Number of dimensions of input image should be 2 or 3, got {len(img_arr.shape)}." + ) # Normalize data if normalize and data_range: From c6766d240b16bc1cb2d3d2b8c9a223aaef4ebe2f Mon Sep 17 00:00:00 2001 From: lukasbehammer Date: Wed, 9 Apr 2025 11:24:41 +0200 Subject: [PATCH 3/6] refactor: refactor to comply with ruff --- src/viqa/utils/loading.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/viqa/utils/loading.py b/src/viqa/utils/loading.py index 338fd48..5c97013 100644 --- a/src/viqa/utils/loading.py +++ b/src/viqa/utils/loading.py @@ -739,15 +739,18 @@ def load_data( ) # Raise exception if input type is not supported if batch: + # TODO: Deprecate in version 4.0.0 for img_num, img in enumerate(img_arr): if len(img.shape) > 3 or len(img.shape) < 2: raise ValueError( - f"Number of dimensions of input image should be 2 or 3, got {len(img.shape)} for image number {img_num} in img_arr." + f"Number of dimensions of input image should be 2 or 3, got " + f"{len(img.shape)} for image number {img_num} in img_arr." ) elif not isinstance(img_arr, list): if len(img_arr.shape) > 3 or len(img_arr.shape) < 2: raise ValueError( - f"Number of dimensions of input image should be 2 or 3, got {len(img_arr.shape)}." + f"Number of dimensions of input image should be 2 or 3, got " + f"{len(img_arr.shape)}." 
) # Normalize data From 44f7d2e2221b3fc60a4df82f03f51a194dba98b3 Mon Sep 17 00:00:00 2001 From: lukasbehammer Date: Wed, 9 Apr 2025 11:59:11 +0200 Subject: [PATCH 4/6] docs(load_data): add Exception to function docstring --- src/viqa/utils/loading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/viqa/utils/loading.py b/src/viqa/utils/loading.py index 5c97013..f6f68af 100644 --- a/src/viqa/utils/loading.py +++ b/src/viqa/utils/loading.py @@ -655,7 +655,8 @@ def load_data( ------ ValueError If input type is not supported \n - If ``data_range=None`` and ``normalize=True`` + If ``data_range=None`` and ``normalize=True`` \n + If number of dimensions not 2 or 3 Warns ----- From c12d2614f9ded3dada55816752307a477e79b322 Mon Sep 17 00:00:00 2001 From: Lukas Nepelius Date: Tue, 22 Apr 2025 17:16:08 +0200 Subject: [PATCH 5/6] Device can be set for torch tensor images --- src/viqa/_metrics.py | 1 + src/viqa/fr_metrics/fsim.py | 8 ++++++++ src/viqa/fr_metrics/msssim.py | 8 ++++++++ src/viqa/fr_metrics/vif.py | 8 ++++++++ src/viqa/fr_metrics/vsi.py | 8 ++++++++ src/viqa/utils/misc.py | 14 +++++++------- 6 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/viqa/_metrics.py b/src/viqa/_metrics.py index 36dacfc..f0482e0 100644 --- a/src/viqa/_metrics.py +++ b/src/viqa/_metrics.py @@ -30,6 +30,7 @@ def __init__(self, data_range, normalize, **kwargs): "normalize": normalize, "chromatic": False, "roi": None, + "device": "cpu" **kwargs, } self.score_val = None diff --git a/src/viqa/fr_metrics/fsim.py b/src/viqa/fr_metrics/fsim.py index 7809db6..5202093 100644 --- a/src/viqa/fr_metrics/fsim.py +++ b/src/viqa/fr_metrics/fsim.py @@ -66,6 +66,9 @@ class FSIM(FullReferenceMetricsInterface): If True, the input images are expected to be RGB images and FSIMc is calculated. See [1]_. Passed to :py:func:`piq.fsim`. See the documentation under [2]_. 
+ device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... Raises ------ @@ -197,6 +200,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -210,6 +214,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -223,6 +228,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -246,6 +252,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -268,6 +275,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, diff --git a/src/viqa/fr_metrics/msssim.py b/src/viqa/fr_metrics/msssim.py index ae20403..6324a86 100644 --- a/src/viqa/fr_metrics/msssim.py +++ b/src/viqa/fr_metrics/msssim.py @@ -145,6 +145,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): Algorithm parameter, K2 (small constant, see [3]_). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... .. 
seealso:: See :py:func:`.viqa.fr_metrics.ssim.structural_similarity` for more @@ -202,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -214,6 +218,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -226,6 +231,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -249,6 +255,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -270,6 +277,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, diff --git a/src/viqa/fr_metrics/vif.py b/src/viqa/fr_metrics/vif.py index c2e1b42..d7b3866 100644 --- a/src/viqa/fr_metrics/vif.py +++ b/src/viqa/fr_metrics/vif.py @@ -127,6 +127,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): HVS model parameter (variance of the visual noise). See [3]_. reduction : str, default='mean' Specifies the reduction type: 'none', 'mean' or 'sum'. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... 
Returns ------- @@ -176,6 +179,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -188,6 +192,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -200,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -222,6 +228,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -243,6 +250,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, diff --git a/src/viqa/fr_metrics/vsi.py b/src/viqa/fr_metrics/vsi.py index c862272..5ca40ae 100644 --- a/src/viqa/fr_metrics/vsi.py +++ b/src/viqa/fr_metrics/vsi.py @@ -64,6 +64,9 @@ class VSI(FullReferenceMetricsInterface): ---------------- chromatic : bool, default False If True, the input images are expected to be RGB images. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... 
Raises ------ @@ -197,6 +200,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -209,6 +213,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -221,6 +226,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -243,6 +249,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -264,6 +271,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, diff --git a/src/viqa/utils/misc.py b/src/viqa/utils/misc.py index f897d97..5910b60 100644 --- a/src/viqa/utils/misc.py +++ b/src/viqa/utils/misc.py @@ -417,7 +417,7 @@ def _get_range(cols_rows): return res -def _check_chromatic(img_r, img_m, chromatic): +def _check_chromatic(img_r, img_m, chromatic, device): """Permute image based on dimensions and chromaticity.""" img_r = _to_float(img_r, np.float32) img_m = _to_float(img_m, np.float32) @@ -425,17 +425,17 @@ def _check_chromatic(img_r, img_m, chromatic): if chromatic is False: if img_r.ndim == 3: # 3D images - img_r_tensor = torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2) - img_m_tensor = torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2) + img_r_tensor = torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2).to(device) + img_m_tensor = torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2).to(device) elif img_r.ndim == 2: # 2D images - img_r_tensor = 
torch.tensor(img_r).unsqueeze(0).unsqueeze(0) - img_m_tensor = torch.tensor(img_m).unsqueeze(0).unsqueeze(0) + img_r_tensor = torch.tensor(img_r).unsqueeze(0).unsqueeze(0).to(device) + img_m_tensor = torch.tensor(img_m).unsqueeze(0).unsqueeze(0).to(device) else: raise ValueError("Image format not supported.") else: - img_r_tensor = torch.tensor(img_r).permute(2, 0, 1).unsqueeze(0) - img_m_tensor = torch.tensor(img_m).permute(2, 0, 1).unsqueeze(0) + img_r_tensor = torch.tensor(img_r).permute(2, 0, 1).unsqueeze(0).to(device) + img_m_tensor = torch.tensor(img_m).permute(2, 0, 1).unsqueeze(0).to(device) return img_r_tensor, img_m_tensor From 5025415c8b50566dccdf9ef3413b3ceef5fe118e Mon Sep 17 00:00:00 2001 From: Lukas Nepelius Date: Wed, 23 Apr 2025 14:18:49 +0200 Subject: [PATCH 6/6] refactor: ensure ruff rules --- src/viqa/_metrics.py | 4 ++-- src/viqa/utils/misc.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/viqa/_metrics.py b/src/viqa/_metrics.py index f0482e0..cba18ee 100644 --- a/src/viqa/_metrics.py +++ b/src/viqa/_metrics.py @@ -30,8 +30,8 @@ def __init__(self, data_range, normalize, **kwargs): "normalize": normalize, "chromatic": False, "roi": None, - "device": "cpu" - **kwargs, + "device": "cpu", + **kwargs, } self.score_val = None self._name = None diff --git a/src/viqa/utils/misc.py b/src/viqa/utils/misc.py index 5910b60..3c359da 100644 --- a/src/viqa/utils/misc.py +++ b/src/viqa/utils/misc.py @@ -425,8 +425,12 @@ def _check_chromatic(img_r, img_m, chromatic, device): if chromatic is False: if img_r.ndim == 3: # 3D images - img_r_tensor = torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2).to(device) - img_m_tensor = torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2).to(device) + img_r_tensor = ( torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2).to(device) ) + img_m_tensor = ( torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2).to(device) ) elif img_r.ndim == 2: # 2D images img_r_tensor = 
torch.tensor(img_r).unsqueeze(0).unsqueeze(0).to(device)