diff --git a/src/viqa/_metrics.py b/src/viqa/_metrics.py index 36dacfca..cba18ee2 100644 --- a/src/viqa/_metrics.py +++ b/src/viqa/_metrics.py @@ -30,7 +30,7 @@ def __init__(self, data_range, normalize, **kwargs): "normalize": normalize, "chromatic": False, "roi": None, - **kwargs, + "device": "cpu", **kwargs, } self.score_val = None self._name = None diff --git a/src/viqa/fr_metrics/fsim.py b/src/viqa/fr_metrics/fsim.py index 7809db6c..52020936 100644 --- a/src/viqa/fr_metrics/fsim.py +++ b/src/viqa/fr_metrics/fsim.py @@ -66,6 +66,9 @@ class FSIM(FullReferenceMetricsInterface): If True, the input images are expected to be RGB images and FSIMc is calculated. See [1]_. Passed to :py:func:`piq.fsim`. See the documentation under [2]_. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... Raises ------ @@ -197,6 +200,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -210,6 +214,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -223,6 +228,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -246,6 +252,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, @@ -268,6 +275,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = fsim( img_r_tensor, diff --git 
a/src/viqa/fr_metrics/msssim.py b/src/viqa/fr_metrics/msssim.py index ae20403c..6324a86d 100644 --- a/src/viqa/fr_metrics/msssim.py +++ b/src/viqa/fr_metrics/msssim.py @@ -145,6 +145,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): Algorithm parameter, K2 (small constant, see [3]_). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... .. seealso:: See :py:func:`.viqa.fr_metrics.ssim.structural_similarity` for more @@ -202,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -214,6 +218,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -226,6 +231,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -249,6 +255,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, @@ -270,6 +277,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = multi_scale_ssim( img_r_tensor, diff --git a/src/viqa/fr_metrics/vif.py b/src/viqa/fr_metrics/vif.py index c2e1b42d..d7b38667 100644 --- a/src/viqa/fr_metrics/vif.py +++ b/src/viqa/fr_metrics/vif.py @@ -127,6 +127,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): HVS model parameter 
(variance of the visual noise). See [3]_. reduction : str, default='mean' Specifies the reduction type: 'none', 'mean' or 'sum'. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... Returns ------- @@ -176,6 +179,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -188,6 +192,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -200,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -222,6 +228,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, @@ -243,6 +250,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vif_p( img_r_tensor, diff --git a/src/viqa/fr_metrics/vsi.py b/src/viqa/fr_metrics/vsi.py index c8622728..5ca40ae4 100644 --- a/src/viqa/fr_metrics/vsi.py +++ b/src/viqa/fr_metrics/vsi.py @@ -64,6 +64,9 @@ class VSI(FullReferenceMetricsInterface): ---------------- chromatic : bool, default False If True, the input images are expected to be RGB images. + device : Union[str, torch.device], default 'cpu' + Determines the device if the image is a PyTorch tensor, + e.g. "cuda", "cpu", "cuda:0", ... 
Raises ------ @@ -197,6 +200,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[im_slice, :, :], img_m[im_slice, :, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -209,6 +213,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, im_slice, :], img_m[:, im_slice, :], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -221,6 +226,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r[:, :, im_slice], img_m[:, :, im_slice], self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -243,6 +249,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, @@ -264,6 +271,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs): img_r, img_m, self.parameters["chromatic"], + self.parameters["device"], ) score_val = vsi( img_r_tensor, diff --git a/src/viqa/utils/misc.py b/src/viqa/utils/misc.py index f897d972..3c359da4 100644 --- a/src/viqa/utils/misc.py +++ b/src/viqa/utils/misc.py @@ -417,7 +417,7 @@ def _get_range(cols_rows): return res -def _check_chromatic(img_r, img_m, chromatic): +def _check_chromatic(img_r, img_m, chromatic, device): """Permute image based on dimensions and chromaticity.""" img_r = _to_float(img_r, np.float32) img_m = _to_float(img_m, np.float32) @@ -425,17 +425,21 @@ def _check_chromatic(img_r, img_m, chromatic): if chromatic is False: if img_r.ndim == 3: # 3D images - img_r_tensor = torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2) - img_m_tensor = torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2) + img_r_tensor = ( + torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2).to(device) + ) + img_m_tensor = ( + torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2).to(device) + ) elif img_r.ndim == 2: # 2D images - 
img_r_tensor = torch.tensor(img_r).unsqueeze(0).unsqueeze(0) - img_m_tensor = torch.tensor(img_m).unsqueeze(0).unsqueeze(0) + img_r_tensor = torch.tensor(img_r).unsqueeze(0).unsqueeze(0).to(device) + img_m_tensor = torch.tensor(img_m).unsqueeze(0).unsqueeze(0).to(device) else: raise ValueError("Image format not supported.") else: - img_r_tensor = torch.tensor(img_r).permute(2, 0, 1).unsqueeze(0) - img_m_tensor = torch.tensor(img_m).permute(2, 0, 1).unsqueeze(0) + img_r_tensor = torch.tensor(img_r).permute(2, 0, 1).unsqueeze(0).to(device) + img_m_tensor = torch.tensor(img_m).permute(2, 0, 1).unsqueeze(0).to(device) return img_r_tensor, img_m_tensor