2020from torch .testing ._comparison import BooleanPair , NonePair , not_close_error_metas , NumberPair , TensorLikePair
2121from torchvision import io , tv_tensors
2222from torchvision .transforms ._functional_tensor import _max_value as get_max_value
23- from torchvision .transforms .v2 .functional import to_cvcuda_tensor , to_image , to_pil_image
23+ from torchvision .transforms .v2 .functional import cvcuda_to_tensor , to_cvcuda_tensor , to_image , to_pil_image
24+ from torchvision .transforms .v2 .functional ._utils import _import_cvcuda , _is_cvcuda_available
2425from torchvision .utils import _Image_fromarray
2526
2627
2728IN_OSS_CI = any (os .getenv (var ) == "true" for var in ["CIRCLECI" , "GITHUB_ACTIONS" ])
2829IN_RE_WORKER = os .environ .get ("INSIDE_RE_WORKER" ) is not None
2930IN_FBCODE = os .environ .get ("IN_FBCODE_TORCHVISION" ) == "1"
31+ CVCUDA_AVAILABLE = _is_cvcuda_available ()
3032CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3133MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3234OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
275277 return [dict (zip (kwargs .keys (), values )) for values in itertools .product (* kwargs .values ())]
276278
277279
280+ def cvcuda_to_pil_compatible_tensor (tensor ):
281+ tensor = cvcuda_to_tensor (tensor )
282+ if tensor .ndim != 4 :
283+ raise ValueError (f"CV-CUDA Tensor should be 4 dimensional. Got { tensor .ndim } dimensions." )
284+ if tensor .shape [0 ] != 1 :
285+ raise ValueError (
286+ f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got { tensor .shape [0 ]} ."
287+ )
288+ return tensor .squeeze (0 ).cpu ()
289+
290+
278291class ImagePair (TensorLikePair ):
279292 def __init__ (
280293 self ,
@@ -287,6 +300,16 @@ def __init__(
287300 if all (isinstance (input , PIL .Image .Image ) for input in [actual , expected ]):
288301 actual , expected = (to_image (input ) for input in [actual , expected ])
289302
303+ if CVCUDA_AVAILABLE and all (isinstance (input , _import_cvcuda ().Tensor ) for input in [actual , expected ]):
304+ actual , expected = (cvcuda_to_tensor (input ) for input in [actual , expected ])
305+
306+ if CVCUDA_AVAILABLE and isinstance (actual , _import_cvcuda ().Tensor ) and isinstance (expected , PIL .Image .Image ):
307+ actual = cvcuda_to_pil_compatible_tensor (actual )
308+ expected = to_image (expected )
309+
310+ if CVCUDA_AVAILABLE and isinstance (actual , _import_cvcuda ().Tensor ):
311+ actual = cvcuda_to_pil_compatible_tensor (actual )
312+
290313 super ().__init__ (actual , expected , ** other_parameters )
291314 self .mae = mae
292315
0 commit comments