Skip to content

Commit 07341f9

Browse files
committed
consolidate gaussian_blur_image to use new validate_kernel_size_and_sigma
1 parent 937c31f commit 07341f9

File tree

1 file changed

+5
-31
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+5
-31
lines changed

torchvision/transforms/v2/functional/_misc.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional, Sequence, TYPE_CHECKING
2+
from typing import List, Optional, TYPE_CHECKING, Union
33

44
import PIL.Image
55
import torch
@@ -107,10 +107,10 @@ def _get_gaussian_kernel2d(
107107

108108

109109
def _validate_kernel_size_and_sigma(
110-
kernel_size: Sequence[int] | int,
111-
sigma: Sequence[float | int] | float | int | None = None,
110+
kernel_size: List[int] | int,
111+
sigma: Optional[Union[List[float], float, int]] = None,
112112
) -> tuple[list[int], list[float]]:
113-
# duplicated logic from gaussian_blur_image for use in gaussian_blur_cvcuda
113+
# TODO: consider deprecating integers from sigma on the future
114114
if isinstance(kernel_size, int):
115115
kernel_size = [kernel_size, kernel_size]
116116
elif len(kernel_size) != 2:
@@ -146,33 +146,7 @@ def _validate_kernel_size_and_sigma(
146146
def gaussian_blur_image(
147147
image: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None
148148
) -> torch.Tensor:
149-
# TODO: consider deprecating integers from sigma on the future
150-
if isinstance(kernel_size, int):
151-
kernel_size = [kernel_size, kernel_size]
152-
elif len(kernel_size) != 2:
153-
raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
154-
for ksize in kernel_size:
155-
if ksize % 2 == 0 or ksize < 0:
156-
raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
157-
158-
if sigma is None:
159-
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
160-
else:
161-
if isinstance(sigma, (list, tuple)):
162-
length = len(sigma)
163-
if length == 1:
164-
s = sigma[0]
165-
sigma = [s, s]
166-
elif length != 2:
167-
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
168-
elif isinstance(sigma, (int, float)):
169-
s = float(sigma)
170-
sigma = [s, s]
171-
else:
172-
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
173-
for s in sigma:
174-
if s <= 0.0:
175-
raise ValueError(f"sigma should have positive values. Got {sigma}")
149+
kernel_size, sigma = _validate_kernel_size_and_sigma(kernel_size, sigma)
176150

177151
if image.numel() == 0:
178152
return image

0 commit comments

Comments
 (0)