Skip to content

Commit 3bcc517

Browse files
committed
update gaussian blur with main standards
1 parent 5ce83b1 commit 3bcc517

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

test/test_transforms_v2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -3958,7 +3957,7 @@ def test_functional(self, make_input):
39583957
(F.gaussian_blur_image, tv_tensors.Image),
39593958
(F.gaussian_blur_video, tv_tensors.Video),
39603959
pytest.param(
3961-
F._misc._gaussian_blur_cvcuda,
3960+
F._misc._gaussian_blur_image_cvcuda,
39623961
"cvcuda.Tensor",
39633962
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
39643963
),

torchvision/transforms/v2/_misc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torchvision import transforms as _transforms, tv_tensors
1111
from torchvision.transforms.v2 import functional as F, Transform
12+
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
1213

1314
from ._utils import (
1415
_parse_labels_getter,
@@ -20,6 +21,8 @@
2021
is_pure_tensor,
2122
)
2223

24+
CVCUDA_AVAILABLE = _is_cvcuda_available()
25+
2326

2427
# TODO: do we want/need to expose this?
2528
class Identity(Transform):
@@ -192,6 +195,9 @@ class GaussianBlur(Transform):
192195

193196
_v1_transform_cls = _transforms.GaussianBlur
194197

198+
if CVCUDA_AVAILABLE:
199+
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
200+
195201
def __init__(
196202
self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
197203
) -> None:

torchvision/transforms/v2/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from torchvision._utils import sequence_to_str
1616

1717
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
18-
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
19-
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
18+
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
19+
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor
2020

2121

2222
def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
182182
chws = {
183183
tuple(get_dimensions(inpt))
184184
for inpt in flat_inputs
185-
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
185+
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
186186
}
187187
if not chws:
188188
raise TypeError("No image or video was found in the sample")
@@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
207207
tv_tensors.Mask,
208208
tv_tensors.BoundingBoxes,
209209
tv_tensors.KeyPoints,
210-
is_cvcuda_tensor,
210+
_is_cvcuda_tensor,
211211
),
212212
)
213213
}

torchvision/transforms/v2/functional/_misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def gaussian_blur_video(
197197
return gaussian_blur_image(video, kernel_size, sigma)
198198

199199

200-
def _gaussian_blur_cvcuda(
200+
def _gaussian_blur_image_cvcuda(
201201
image: "cvcuda.Tensor", kernel_size: list[int], sigma: Optional[list[float]] = None
202202
) -> "cvcuda.Tensor":
203203
cvcuda = _import_cvcuda()
@@ -213,7 +213,7 @@ def _gaussian_blur_cvcuda(
213213

214214

215215
if CVCUDA_AVAILABLE:
216-
_register_kernel_internal(gaussian_blur, _import_cvcuda().Tensor)(_gaussian_blur_cvcuda)
216+
_register_kernel_internal(gaussian_blur, _import_cvcuda().Tensor)(_gaussian_blur_image_cvcuda)
217217

218218

219219
def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:

0 commit comments

Comments
 (0)