Skip to content

Commit fe4824b

Browse files
committed
use transformed_types in child classes not transform
1 parent 12a258b commit fe4824b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

torchvision/transforms/v2/_color.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torchvision.transforms.v2 import functional as F, Transform
88

99
from ._transform import _RandomApplyTransform
10-
from ._utils import query_chw
10+
from ._utils import is_cvcuda_tensor, query_chw
1111

1212

1313
class Grayscale(Transform):
@@ -22,6 +22,8 @@ class Grayscale(Transform):
2222

2323
_v1_transform_cls = _transforms.Grayscale
2424

25+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
26+
2527
def __init__(self, num_output_channels: int = 1):
2628
super().__init__()
2729
self.num_output_channels = num_output_channels
@@ -44,6 +46,8 @@ class RandomGrayscale(_RandomApplyTransform):
4446

4547
_v1_transform_cls = _transforms.RandomGrayscale
4648

49+
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
50+
4751
def __init__(self, p: float = 0.1) -> None:
4852
super().__init__(p=p)
4953

@@ -62,6 +66,8 @@ class RGB(Transform):
6266
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
6367
"""
6468

69+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
70+
6571
def __init__(self):
6672
super().__init__()
6773

0 commit comments

Comments
 (0)