Skip to content

Commit 9cd7582

Browse files
committed
resolve more review comments
1 parent 5df3a7d commit 9cd7582

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

test/test_transforms_v2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4085,7 +4085,7 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp
40854085
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
40864086
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
40874087
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
4088-
def test_functional_cvcuda_parity(self, dimensions, kernel_size, sigma, color_space, batch_dims, dtype):
4088+
def test_functional_cvcuda_correctness(self, dimensions, kernel_size, sigma, color_space, batch_dims, dtype):
40894089
height, width = dimensions
40904090

40914091
image_tensor = make_image(
@@ -4098,6 +4098,9 @@ def test_functional_cvcuda_parity(self, dimensions, kernel_size, sigma, color_sp
40984098
actual_torch = F.cvcuda_to_tensor(actual)
40994099

41004100
if dtype.is_floating_point:
4101+
# floating point precision differences between torch and cvcuda are present
4102+
# observed greatest absolute error is 0.3
4103+
# most likely stemming from different implementation
41014104
torch.testing.assert_close(actual_torch, expected, rtol=0, atol=0.3)
41024105
else:
41034106
# uint8/16 gaussians can differ by up to max-value, most likely an overflow issue

torchvision/transforms/v2/functional/_misc.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import List, Optional, TYPE_CHECKING, Union
2+
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
33

44
import PIL.Image
55
import torch
@@ -109,7 +109,7 @@ def _get_gaussian_kernel2d(
109109
def _validate_kernel_size_and_sigma(
110110
kernel_size: List[int] | int,
111111
sigma: Optional[Union[List[float], float, int]] = None,
112-
) -> tuple[list[int], list[float]]:
112+
) -> Tuple[List[int], List[float]]:
113113
# TODO: consider deprecating integers from sigma on the future
114114
if isinstance(kernel_size, int):
115115
kernel_size = [kernel_size, kernel_size]
@@ -213,9 +213,7 @@ def _gaussian_blur_cvcuda(
213213

214214

215215
if CVCUDA_AVAILABLE:
216-
_gaussian_blur_cvcuda_registered = _register_kernel_internal(gaussian_blur, _import_cvcuda().Tensor)(
217-
_gaussian_blur_cvcuda
218-
)
216+
_register_kernel_internal(gaussian_blur, _import_cvcuda().Tensor)(_gaussian_blur_cvcuda)
219217

220218

221219
def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:

0 commit comments

Comments
 (0)