|
1 | 1 | import math |
2 | | -from typing import Optional, Sequence, TYPE_CHECKING |
| 2 | +from typing import List, Optional, TYPE_CHECKING, Union |
3 | 3 |
|
4 | 4 | import PIL.Image |
5 | 5 | import torch |
@@ -107,10 +107,10 @@ def _get_gaussian_kernel2d( |
107 | 107 |
|
108 | 108 |
|
109 | 109 | def _validate_kernel_size_and_sigma( |
110 | | - kernel_size: Sequence[int] | int, |
111 | | - sigma: Sequence[float | int] | float | int | None = None, |
| 110 | + kernel_size: List[int] | int, |
| 111 | + sigma: Optional[Union[List[float], float, int]] = None, |
112 | 112 | ) -> tuple[list[int], list[float]]: |
113 | | - # duplicated logic from gaussian_blur_image for use in gaussian_blur_cvcuda |
| 113 | + # TODO: consider deprecating integers from sigma on the future |
114 | 114 | if isinstance(kernel_size, int): |
115 | 115 | kernel_size = [kernel_size, kernel_size] |
116 | 116 | elif len(kernel_size) != 2: |
@@ -146,33 +146,7 @@ def _validate_kernel_size_and_sigma( |
146 | 146 | def gaussian_blur_image( |
147 | 147 | image: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None |
148 | 148 | ) -> torch.Tensor: |
149 | | - # TODO: consider deprecating integers from sigma on the future |
150 | | - if isinstance(kernel_size, int): |
151 | | - kernel_size = [kernel_size, kernel_size] |
152 | | - elif len(kernel_size) != 2: |
153 | | - raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") |
154 | | - for ksize in kernel_size: |
155 | | - if ksize % 2 == 0 or ksize < 0: |
156 | | - raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}") |
157 | | - |
158 | | - if sigma is None: |
159 | | - sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] |
160 | | - else: |
161 | | - if isinstance(sigma, (list, tuple)): |
162 | | - length = len(sigma) |
163 | | - if length == 1: |
164 | | - s = sigma[0] |
165 | | - sigma = [s, s] |
166 | | - elif length != 2: |
167 | | - raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}") |
168 | | - elif isinstance(sigma, (int, float)): |
169 | | - s = float(sigma) |
170 | | - sigma = [s, s] |
171 | | - else: |
172 | | - raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") |
173 | | - for s in sigma: |
174 | | - if s <= 0.0: |
175 | | - raise ValueError(f"sigma should have positive values. Got {sigma}") |
| 149 | + kernel_size, sigma = _validate_kernel_size_and_sigma(kernel_size, sigma) |
176 | 150 |
|
177 | 151 | if image.numel() == 0: |
178 | 152 | return image |
|
0 commit comments