Skip to content

Commit b4e7e45

Browse files
committed
use transformed_types in child classes not transform
1 parent c4a2bb8 commit b4e7e45

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

torchvision/transforms/v2/_color.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torchvision.transforms.v2 import functional as F, Transform
88

99
from ._transform import _RandomApplyTransform
10-
from ._utils import query_chw
10+
from ._utils import is_cvcuda_tensor, query_chw
1111

1212

1313
class Grayscale(Transform):
@@ -22,6 +22,8 @@ class Grayscale(Transform):
2222

2323
_v1_transform_cls = _transforms.Grayscale
2424

25+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
26+
2527
def __init__(self, num_output_channels: int = 1):
2628
super().__init__()
2729
self.num_output_channels = num_output_channels
@@ -44,6 +46,8 @@ class RandomGrayscale(_RandomApplyTransform):
4446

4547
_v1_transform_cls = _transforms.RandomGrayscale
4648

49+
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
50+
4751
def __init__(self, p: float = 0.1) -> None:
4852
super().__init__(p=p)
4953

@@ -62,6 +66,8 @@ class RGB(Transform):
6266
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
6367
"""
6468

69+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
70+
6571
def __init__(self):
6672
super().__init__()
6773

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

14-
from .functional._utils import _get_kernel, is_cvcuda_tensor
14+
from .functional._utils import _get_kernel
1515

1616

1717
class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

0 commit comments

Comments
 (0)