Skip to content

Commit 4f870bc

Browse files
committed
move transformed tpyes to equalie transform class
1 parent feb36a2 commit 4f870bc

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

torchvision/transforms/v2/_color.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ._transform import _RandomApplyTransform
1010
from ._utils import query_chw
11+
from .functional._utils import is_cvcuda_tensor
1112

1213

1314
class Grayscale(Transform):
@@ -265,6 +266,8 @@ class RandomEqualize(_RandomApplyTransform):
265266

266267
_v1_transform_cls = _transforms.RandomEqualize
267268

269+
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
270+
268271
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
269272
return self._call_kernel(F.equalize, inpt)
270273

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

14-
from .functional._utils import _get_kernel, is_cvcuda_tensor
14+
from .functional._utils import _get_kernel
1515

1616

1717
class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

0 commit comments

Comments
 (0)